Merge branch 'main' into save_vram

This commit is contained in:
StAlKeR7779 2023-07-18 16:55:48 +03:00 committed by GitHub
commit 889b77d3d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
125 changed files with 3745 additions and 10444 deletions

View File

@ -132,8 +132,10 @@ and go to http://localhost:9090.
### Command-Line Installation (for developers and users familiar with Terminals) ### Command-Line Installation (for developers and users familiar with Terminals)
You must have Python 3.9 or 3.10 installed on your machine. Earlier or later versions are You must have Python 3.9 or 3.10 installed on your machine. Earlier or
not supported. later versions are not supported.
Node.js also needs to be installed along with yarn (can be installed with
the command `npm install -g yarn` if needed)
1. Open a command-line window on your machine. The PowerShell is recommended for Windows. 1. Open a command-line window on your machine. The PowerShell is recommended for Windows.
2. Create a directory to install InvokeAI into. You'll need at least 15 GB of free space: 2. Create a directory to install InvokeAI into. You'll need at least 15 GB of free space:
@ -197,11 +199,18 @@ not supported.
7. Launch the web server (do it every time you run InvokeAI): 7. Launch the web server (do it every time you run InvokeAI):
```terminal ```terminal
invokeai --web invokeai-web
``` ```
8. Point your browser to http://localhost:9090 to bring up the web interface. 8. Build Node.js assets
9. Type `banana sushi` in the box on the top left and click `Invoke`.
```terminal
cd invokeai/frontend/web/
yarn vite build
```
9. Point your browser to http://localhost:9090 to bring up the web interface.
10. Type `banana sushi` in the box on the top left and click `Invoke`.
Be sure to activate the virtual environment each time before re-launching InvokeAI, Be sure to activate the virtual environment each time before re-launching InvokeAI,
using `source .venv/bin/activate` or `.venv\Scripts\activate`. using `source .venv/bin/activate` or `.venv\Scripts\activate`.

View File

@ -81,3 +81,193 @@ pytest --cov; open ./coverage/html/index.html
<!--#TODO: get input from blessedcoolant here, for the moment inserted the frontend README via snippets extension.--> <!--#TODO: get input from blessedcoolant here, for the moment inserted the frontend README via snippets extension.-->
--8<-- "invokeai/frontend/web/README.md" --8<-- "invokeai/frontend/web/README.md"
## Developing InvokeAI in VSCode
VSCode offers some nice tools:
- python debugger
- automatic `venv` activation
- remote dev (e.g. run InvokeAI on a beefy linux desktop while you type in
comfort on your macbook)
### Setup
You'll need the
[Python](https://marketplace.visualstudio.com/items?itemName=ms-python.python)
and
[Pylance](https://marketplace.visualstudio.com/items?itemName=ms-python.vscode-pylance)
extensions installed first.
It's also really handy to install the `Jupyter` extensions:
- [Jupyter](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.jupyter)
- [Jupyter Cell Tags](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.vscode-jupyter-cell-tags)
- [Jupyter Notebook Renderers](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.jupyter-renderers)
- [Jupyter Slide Show](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.vscode-jupyter-slideshow)
#### InvokeAI workspace
Creating a VSCode workspace for working on InvokeAI is highly recommended. It
can hold InvokeAI-specific settings and configs.
To make a workspace:
- Open the InvokeAI repo dir in VSCode
- `File` > `Save Workspace As` > save it _outside_ the repo
#### Default python interpreter (i.e. automatic virtual environment activation)
- Use command palette to run command
`Preferences: Open Workspace Settings (JSON)`
- Add `python.defaultInterpreterPath` to `settings`, pointing to your `venv`'s
python
Should look something like this:
```jsonc
{
// I like to have all InvokeAI-related folders in my workspace
"folders": [
{
// repo root
"path": "InvokeAI"
},
{
// InvokeAI root dir, where `invokeai.yaml` lives
"path": "/path/to/invokeai_root"
}
],
"settings": {
// Where your InvokeAI `venv`'s python executable lives
"python.defaultInterpreterPath": "/path/to/invokeai_root/.venv/bin/python"
}
}
```
Now when you open the VSCode integrated terminal, or do anything that needs to
run python, it will automatically be in your InvokeAI virtual environment.
Bonus: When you create a Jupyter notebook, when you run it, you'll be prompted
for the python interpreter to run in. This will default to your `venv` python,
and so you'll have access to the same python environment as the InvokeAI app.
This is _super_ handy.
#### Debugging configs with `launch.json`
Debugging configs are managed in a `launch.json` file. Like most VSCode configs,
these can be scoped to a workspace or folder.
Follow the [official guide](https://code.visualstudio.com/docs/python/debugging)
to set up your `launch.json` and try it out.
Now we can create the InvokeAI debugging configs:
```jsonc
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
// Run the InvokeAI backend & serve the pre-built UI
"name": "InvokeAI Web",
"type": "python",
"request": "launch",
"program": "scripts/invokeai-web.py",
"args": [
// Your InvokeAI root dir (where `invokeai.yaml` lives)
"--root",
"/path/to/invokeai_root",
// Access the app from anywhere on your local network
"--host",
"0.0.0.0"
],
"justMyCode": true
},
{
// Run the nodes-based CLI
"name": "InvokeAI CLI",
"type": "python",
"request": "launch",
"program": "scripts/invokeai-cli.py",
"justMyCode": true
},
{
// Run tests
"name": "InvokeAI Test",
"type": "python",
"request": "launch",
"module": "pytest",
"args": ["--capture=no"],
"justMyCode": true
},
{
// Run a single test
"name": "InvokeAI Single Test",
"type": "python",
"request": "launch",
"module": "pytest",
"args": [
// Change this to point to the specific test you are working on
"tests/nodes/test_invoker.py"
],
"justMyCode": true
},
{
// This is the default, useful to just run a single file
"name": "Python: File",
"type": "python",
"request": "launch",
"program": "${file}",
"justMyCode": true
}
]
}
```
You'll see these configs in the debugging configs drop down. Running them will
start InvokeAI with attached debugger, in the correct environment, and work just
like the normal app.
Enjoy debugging InvokeAI with ease (not that we have any bugs of course).
#### Remote dev
This is very easy to set up and provides the same very smooth experience as
local development. Environments and debugging, as set up above, just work,
though you'd need to recreate the workspace and debugging configs on the remote.
Consult the
[official guide](https://code.visualstudio.com/docs/remote/remote-overview) to
get it set up.
Suggest using VSCode's included settings sync so that your remote dev host has
all the same app settings and extensions automagically.
##### One remote dev gotcha
I've found the automatic port forwarding to be very flakey. You can disable it
in `Preferences: Open Remote Settings (ssh: hostname)`. Search for
`remote.autoForwardPorts` and untick the box.
To forward ports very reliably, use SSH on the remote dev client (e.g. your
macbook). Here's how to forward both backend API port (`9090`) and the frontend
live dev server port (`5173`):
```bash
ssh \
-L 9090:localhost:9090 \
-L 5173:localhost:5173 \
user@remote-dev-host
```
The forwarding stops when you close the terminal window, so suggest to do this
_outside_ the VSCode integrated terminal in case you need to restart VSCode for
an extension update or something
Now, on your remote dev client, you can open `localhost:9090` and access the UI,
now served from the remote dev host, just the same as if it was running on the
client.

View File

@ -1,4 +1,8 @@
# Nodes Editor (Experimental Beta) # Nodes Editor (Experimental)
🚨
*The node editor is experimental. We've made it accessible because we use it to develop the application, but we have not addressed the many known rough edges. It's very easy to shoot yourself in the foot, and we cannot offer support for it until it sees full release (ETA v3.1). Everything is subject to change without warning.*
🚨
The nodes editor is a blank canvas allowing for the use of individual functions and image transformations to control the image generation workflow. The node processing flow is usually done from left (inputs) to right (outputs), though linearity can become abstracted the more complex the node graph becomes. Nodes inputs and outputs are connected by dragging connectors from node to node. The nodes editor is a blank canvas allowing for the use of individual functions and image transformations to control the image generation workflow. The node processing flow is usually done from left (inputs) to right (outputs), though linearity can become abstracted the more complex the node graph becomes. Nodes inputs and outputs are connected by dragging connectors from node to node.

View File

@ -24,7 +24,8 @@ read -e -p "Tag this repo with '${VERSION}' and '${LATEST_TAG}'? [n]: " input
RESPONSE=${input:='n'} RESPONSE=${input:='n'}
if [ "$RESPONSE" == 'y' ]; then if [ "$RESPONSE" == 'y' ]; then
if ! git tag $VERSION ; then git push origin :refs/tags/$VERSION
if ! git tag -fa $VERSION ; then
echo "Existing/invalid tag" echo "Existing/invalid tag"
exit -1 exit -1
fi fi

View File

@ -38,7 +38,7 @@ echo https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist
echo. echo.
echo See %INSTRUCTIONS% for more details. echo See %INSTRUCTIONS% for more details.
echo. echo.
echo "For the best user experience we suggest enlarging or maximizing this window now." echo FOR THE BEST USER EXPERIENCE WE SUGGEST MAXIMIZING THIS WINDOW NOW.
pause pause
@rem ---------------------------- check Python version --------------- @rem ---------------------------- check Python version ---------------

View File

@ -19,7 +19,7 @@ echo 8. Open the developer console
echo 9. Update InvokeAI echo 9. Update InvokeAI
echo 10. Command-line help echo 10. Command-line help
echo Q - Quit echo Q - Quit
set /P choice="Please enter 1-10, Q: [2] " set /P choice="Please enter 1-10, Q: [1] "
if not defined choice set choice=1 if not defined choice set choice=1
IF /I "%choice%" == "1" ( IF /I "%choice%" == "1" (
echo Starting the InvokeAI browser-based UI.. echo Starting the InvokeAI browser-based UI..

View File

@ -11,6 +11,7 @@ from invokeai.app.services.board_images import (
) )
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.resource_name import SimpleNameService
@ -20,7 +21,6 @@ from invokeai.version.invokeai_version import __version__
from ..services.default_graphs import create_system_graphs from ..services.default_graphs import create_system_graphs
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..services.restoration_services import RestorationServices
from ..services.graph import GraphExecutionState, LibraryGraph from ..services.graph import GraphExecutionState, LibraryGraph
from ..services.image_file_storage import DiskImageFileStorage from ..services.image_file_storage import DiskImageFileStorage
from ..services.invocation_queue import MemoryInvocationQueue from ..services.invocation_queue import MemoryInvocationQueue
@ -57,8 +57,8 @@ class ApiDependencies:
invoker: Invoker = None invoker: Invoker = None
@staticmethod @staticmethod
def initialize(config, event_handler_id: int, logger: Logger = logger): def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
logger.debug(f'InvokeAI version {__version__}') logger.debug(f"InvokeAI version {__version__}")
logger.debug(f"Internet connectivity is {config.internet_available}") logger.debug(f"Internet connectivity is {config.internet_available}")
events = FastAPIEventService(event_handler_id) events = FastAPIEventService(event_handler_id)
@ -117,7 +117,7 @@ class ApiDependencies:
) )
services = InvocationServices( services = InvocationServices(
model_manager=ModelManagerService(config,logger), model_manager=ModelManagerService(config, logger),
events=events, events=events,
latents=latents, latents=latents,
images=images, images=images,
@ -129,7 +129,6 @@ class ApiDependencies:
), ),
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(), processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config, logger),
configuration=config, configuration=config,
logger=logger, logger=logger,
) )

View File

@ -13,8 +13,11 @@ from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management.models import ( from invokeai.backend.model_management.models import (
OPENAPI_MODEL_CONFIGS, OPENAPI_MODEL_CONFIGS,
SchedulerPredictionType, SchedulerPredictionType,
ModelNotFoundException,
InvalidModelException,
) )
from invokeai.backend.model_management import MergeInterpolationMethod from invokeai.backend.model_management import MergeInterpolationMethod
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"]) models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -51,8 +54,9 @@ async def list_models(
"/{base_model}/{model_type}/{model_name}", "/{base_model}/{model_type}/{model_name}",
operation_id="update_model", operation_id="update_model",
responses={200: {"description" : "The model was updated successfully"}, responses={200: {"description" : "The model was updated successfully"},
400: {"description" : "Bad request"},
404: {"description" : "The model could not be found"}, 404: {"description" : "The model could not be found"},
400: {"description" : "Bad request"} 409: {"description" : "There is already a model corresponding to the new name"},
}, },
status_code = 200, status_code = 200,
response_model = UpdateModelResponse, response_model = UpdateModelResponse,
@ -63,23 +67,58 @@ async def update_model(
model_name: str = Path(description="model name"), model_name: str = Path(description="model name"),
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> UpdateModelResponse: ) -> UpdateModelResponse:
""" Add Model """ """ Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. """
logger = ApiDependencies.invoker.services.logger
try: try:
previous_info = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
# rename operation requested
if info.model_name != model_name or info.base_model != base_model:
ApiDependencies.invoker.services.model_manager.rename_model(
base_model = base_model,
model_type = model_type,
model_name = model_name,
new_name = info.model_name,
new_base = info.base_model,
)
logger.info(f'Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}')
# update information to support an update of attributes
model_name = info.model_name
base_model = info.base_model
new_info = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
if new_info.get('path') != previous_info.get('path'): # model manager moved model path during rename - don't overwrite it
info.path = new_info.get('path')
ApiDependencies.invoker.services.model_manager.update_model( ApiDependencies.invoker.services.model_manager.update_model(
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,
model_attributes=info.dict() model_attributes=info.dict()
) )
model_raw = ApiDependencies.invoker.services.model_manager.list_model( model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,
) )
model_response = parse_obj_as(UpdateModelResponse, model_raw) model_response = parse_obj_as(UpdateModelResponse, model_raw)
except KeyError as e: except ModelNotFoundException as e:
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
except ValueError as e: except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
except Exception as e:
logger.error(str(e))
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
return model_response return model_response
@ -90,6 +129,7 @@ async def update_model(
responses= { responses= {
201: {"description" : "The model imported successfully"}, 201: {"description" : "The model imported successfully"},
404: {"description" : "The model could not be found"}, 404: {"description" : "The model could not be found"},
415: {"description" : "Unrecognized file/folder format"},
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"}, 424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description" : "There is already a model corresponding to this path or repo_id"}, 409: {"description" : "There is already a model corresponding to this path or repo_id"},
}, },
@ -116,7 +156,7 @@ async def import_model(
if not info: if not info:
logger.error("Import failed") logger.error("Import failed")
raise HTTPException(status_code=424) raise HTTPException(status_code=415)
logger.info(f'Successfully imported {location}, got {info}') logger.info(f'Successfully imported {location}, got {info}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model( model_raw = ApiDependencies.invoker.services.model_manager.list_model(
@ -126,9 +166,12 @@ async def import_model(
) )
return parse_obj_as(ImportModelResponse, model_raw) return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e: except ModelNotFoundException as e:
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
except ValueError as e: except ValueError as e:
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e)) raise HTTPException(status_code=409, detail=str(e))
@ -166,57 +209,13 @@ async def add_model(
model_type=info.model_type model_type=info.model_type
) )
return parse_obj_as(ImportModelResponse, model_raw) return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e: except ModelNotFoundException as e:
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
except ValueError as e: except ValueError as e:
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e)) raise HTTPException(status_code=409, detail=str(e))
@models_router.post(
"/rename/{base_model}/{model_type}/{model_name}",
operation_id="rename_model",
responses= {
201: {"description" : "The model was renamed successfully"},
404: {"description" : "The model could not be found"},
409: {"description" : "There is already a model corresponding to the new name"},
},
status_code=201,
response_model=ImportModelResponse
)
async def rename_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="current model name"),
new_name: Optional[str] = Query(description="new model name", default=None),
new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
) -> ImportModelResponse:
""" Rename a model"""
logger = ApiDependencies.invoker.services.logger
try:
result = ApiDependencies.invoker.services.model_manager.rename_model(
base_model = base_model,
model_type = model_type,
model_name = model_name,
new_name = new_name,
new_base = new_base,
)
logger.debug(result)
logger.info(f'Successfully renamed {model_name}=>{new_name}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=new_name or model_name,
base_model=new_base or base_model,
model_type=model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.delete( @models_router.delete(
"/{base_model}/{model_type}/{model_name}", "/{base_model}/{model_type}/{model_name}",
@ -243,9 +242,9 @@ async def delete_model(
) )
logger.info(f"Deleted model: {model_name}") logger.info(f"Deleted model: {model_name}")
return Response(status_code=204) return Response(status_code=204)
except KeyError: except ModelNotFoundException as e:
logger.error(f"Model not found: {model_name}") logger.error(str(e))
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") raise HTTPException(status_code=404, detail=str(e))
@models_router.put( @models_router.put(
"/convert/{base_model}/{model_type}/{model_name}", "/convert/{base_model}/{model_type}/{model_name}",
@ -278,8 +277,8 @@ async def convert_model(
base_model = base_model, base_model = base_model,
model_type = model_type) model_type = model_type)
response = parse_obj_as(ConvertModelResponse, model_raw) response = parse_obj_as(ConvertModelResponse, model_raw)
except KeyError: except ModelNotFoundException as e:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
return response return response
@ -369,8 +368,55 @@ async def merge_models(
model_type = ModelType.Main, model_type = ModelType.Main,
) )
response = parse_obj_as(ConvertModelResponse, model_raw) response = parse_obj_as(ConvertModelResponse, model_raw)
except KeyError: except ModelNotFoundException:
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found") raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
return response return response
# The rename operation is now supported by update_model and no longer needs to be
# a standalone route.
# @models_router.post(
# "/rename/{base_model}/{model_type}/{model_name}",
# operation_id="rename_model",
# responses= {
# 201: {"description" : "The model was renamed successfully"},
# 404: {"description" : "The model could not be found"},
# 409: {"description" : "There is already a model corresponding to the new name"},
# },
# status_code=201,
# response_model=ImportModelResponse
# )
# async def rename_model(
# base_model: BaseModelType = Path(description="Base model"),
# model_type: ModelType = Path(description="The type of model"),
# model_name: str = Path(description="current model name"),
# new_name: Optional[str] = Query(description="new model name", default=None),
# new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
# ) -> ImportModelResponse:
# """ Rename a model"""
# logger = ApiDependencies.invoker.services.logger
# try:
# result = ApiDependencies.invoker.services.model_manager.rename_model(
# base_model = base_model,
# model_type = model_type,
# model_name = model_name,
# new_name = new_name,
# new_base = new_base,
# )
# logger.debug(result)
# logger.info(f'Successfully renamed {model_name}=>{new_name}')
# model_raw = ApiDependencies.invoker.services.model_manager.list_model(
# model_name=new_name or model_name,
# base_model=new_base or base_model,
# model_type=model_type
# )
# return parse_obj_as(ImportModelResponse, model_raw)
# except ModelNotFoundException as e:
# logger.error(str(e))
# raise HTTPException(status_code=404, detail=str(e))
# except ValueError as e:
# logger.error(str(e))
# raise HTTPException(status_code=409, detail=str(e))

View File

@ -39,6 +39,7 @@ from .invocations.baseinvocation import BaseInvocation
import torch import torch
import invokeai.backend.util.hotfixes
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes import invokeai.backend.util.mps_fixes

View File

@ -54,10 +54,10 @@ from .services.invocation_services import InvocationServices
from .services.invoker import Invoker from .services.invoker import Invoker
from .services.model_manager_service import ModelManagerService from .services.model_manager_service import ModelManagerService
from .services.processor import DefaultInvocationProcessor from .services.processor import DefaultInvocationProcessor
from .services.restoration_services import RestorationServices
from .services.sqlite import SqliteItemStorage from .services.sqlite import SqliteItemStorage
import torch import torch
import invokeai.backend.util.hotfixes
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes import invokeai.backend.util.mps_fixes
@ -295,7 +295,6 @@ def invoke_cli():
), ),
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(), processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config,logger=logger),
logger=logger, logger=logger,
configuration=config, configuration=config,
) )

View File

@ -86,10 +86,10 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput: def invoke(self, context: InvocationContext) -> CompelOutput:
tokenizer_info = context.services.model_manager.get_model( tokenizer_info = context.services.model_manager.get_model(
**self.clip.tokenizer.dict(), **self.clip.tokenizer.dict(), context=context,
) )
text_encoder_info = context.services.model_manager.get_model( text_encoder_info = context.services.model_manager.get_model(
**self.clip.text_encoder.dict(), **self.clip.text_encoder.dict(), context=context,
) )
def _lora_loader(): def _lora_loader():
@ -111,6 +111,7 @@ class CompelInvocation(BaseInvocation):
model_name=name, model_name=name,
base_model=self.clip.text_encoder.base_model, base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion, model_type=ModelType.TextualInversion,
context=context,
).context.model ).context.model
) )
except ModelNotFoundException: except ModelNotFoundException:
@ -129,7 +130,7 @@ class CompelInvocation(BaseInvocation):
text_encoder=text_encoder, text_encoder=text_encoder,
textual_inversion_manager=ti_manager, textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype, dtype_for_device_getter=torch_dtype,
truncate_long_prompts=False, truncate_long_prompts=True,
) )
conjunction = Compel.parse_prompt_string(self.prompt) conjunction = Compel.parse_prompt_string(self.prompt)
@ -438,7 +439,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
) )
class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning.""" """Pass unmodified prompt to conditioning without compel processing."""
type: Literal["sdxl_raw_prompt"] = "sdxl_raw_prompt" type: Literal["sdxl_raw_prompt"] = "sdxl_raw_prompt"

View File

@ -161,13 +161,13 @@ class InpaintInvocation(BaseInvocation):
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"})) **lora.dict(exclude={"weight"}), context=context,)
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict()) unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context,)
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict()) vae_info = context.services.model_manager.get_model(**self.vae.vae.dict(), context=context,)
with vae_info as vae,\ with vae_info as vae,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\

View File

@ -83,7 +83,7 @@ def get_scheduler(
scheduler_name, SCHEDULER_MAP['ddim'] scheduler_name, SCHEDULER_MAP['ddim']
) )
orig_scheduler_info = context.services.model_manager.get_model( orig_scheduler_info = context.services.model_manager.get_model(
**scheduler_info.dict() **scheduler_info.dict(), context=context,
) )
with orig_scheduler_info as orig_scheduler: with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config scheduler_config = orig_scheduler.config
@ -270,6 +270,7 @@ class TextToLatentsInvocation(BaseInvocation):
model_name=control_info.control_model.model_name, model_name=control_info.control_model.model_name,
model_type=ModelType.ControlNet, model_type=ModelType.ControlNet,
base_model=control_info.control_model.base_model, base_model=control_info.control_model.base_model,
context=context,
) )
) )
@ -321,14 +322,14 @@ class TextToLatentsInvocation(BaseInvocation):
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}) **lora.dict(exclude={"weight"}), context=context,
) )
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
unet_info = context.services.model_manager.get_model( unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict() **self.unet.unet.dict(), context=context,
) )
with ExitStack() as exit_stack,\ with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
@ -414,14 +415,14 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}) **lora.dict(exclude={"weight"}), context=context,
) )
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
unet_info = context.services.model_manager.get_model( unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict() **self.unet.unet.dict(), context=context,
) )
with ExitStack() as exit_stack,\ with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
@ -506,7 +507,7 @@ class LatentsToImageInvocation(BaseInvocation):
latents = context.services.latents.get(self.latents.latents_name) latents = context.services.latents.get(self.latents.latents_name)
vae_info = context.services.model_manager.get_model( vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(), **self.vae.vae.dict(), context=context,
) )
with vae_info as vae: with vae_info as vae:
@ -687,7 +688,7 @@ class ImageToLatentsInvocation(BaseInvocation):
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict()) #vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
vae_info = context.services.model_manager.get_model( vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(), **self.vae.vae.dict(), context=context,
) )
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))

View File

@ -1,6 +1,8 @@
from typing import Literal from os.path import exists
from typing import Literal, Optional
from pydantic.fields import Field import numpy as np
from pydantic import Field, validator
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
@ -55,3 +57,41 @@ class DynamicPromptInvocation(BaseInvocation):
prompts = generator.generate(self.prompt, num_images=self.max_prompts) prompts = generator.generate(self.prompt, num_images=self.max_prompts)
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts)) return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
class PromptsFromFileInvocation(BaseInvocation):
'''Loads prompts from a text file'''
# fmt: off
type: Literal['prompt_from_file'] = 'prompt_from_file'
# Inputs
file_path: str = Field(description="Path to prompt text file")
pre_prompt: Optional[str] = Field(description="String to prepend to each prompt")
post_prompt: Optional[str] = Field(description="String to append to each prompt")
start_line: int = Field(default=1, ge=1, description="Line in the file to start start from")
max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)")
#fmt: on
@validator("file_path")
def file_path_exists(cls, v):
if not exists(v):
raise ValueError(FileNotFoundError)
return v
def promptsFromFile(self, file_path: str, pre_prompt: str, post_prompt: str, start_line: int, max_prompts: int):
prompts = []
start_line -= 1
end_line = start_line + max_prompts
if max_prompts <= 0:
end_line = np.iinfo(np.int32).max
with open(file_path) as f:
for i, line in enumerate(f):
if i >= start_line and i < end_line:
prompts.append((pre_prompt or '') + line.strip() + (post_prompt or ''))
if i >= end_line:
break
return prompts
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
prompts = self.promptsFromFile(self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts)
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))

View File

@ -1,55 +0,0 @@
from typing import Literal, Optional
from pydantic import Field
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
class RestoreFaceInvocation(BaseInvocation):
"""Restores faces in an image."""
# fmt: off
type: Literal["restore_face"] = "restore_face"
# Inputs
image: Optional[ImageField] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
# fmt: on
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["restoration", "image"],
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=None,
strength=self.strength, # GFPGAN strength
save_original=False,
image_callback=None,
)
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_dto = context.services.images.create(
image=results[0][0],
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -1,48 +1,112 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
from pathlib import Path, PosixPath
from typing import Literal, Optional from typing import Literal, Union, cast
import cv2 as cv
import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from pydantic import Field from pydantic import Field
from realesrgan import RealESRGANer
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageOutput from .image import ImageOutput
# TODO: Populate this from disk?
# TODO: Use model manager to load?
REALESRGAN_MODELS = Literal[
"RealESRGAN_x4plus.pth",
"RealESRGAN_x4plus_anime_6B.pth",
"ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
]
class UpscaleInvocation(BaseInvocation):
"""Upscales an image."""
# fmt: off class RealESRGANInvocation(BaseInvocation):
type: Literal["upscale"] = "upscale" """Upscales an image using RealESRGAN."""
# Inputs type: Literal["realesrgan"] = "realesrgan"
image: Optional[ImageField] = Field(description="The input image", default=None) image: Union[ImageField, None] = Field(default=None, description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength") model_name: REALESRGAN_MODELS = Field(
level: Literal[2, 4] = Field(default=2, description="The upscale level") default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use"
# fmt: on )
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["upscaling", "image"],
},
}
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
results = context.services.restoration.upscale_and_reconstruct( models_path = context.services.configuration.models_path
image_list=[[image, 0]],
upscale=(self.level, self.strength), rrdbnet_model = None
strength=0.0, # GFPGAN strength netscale = None
save_original=False, esrgan_model_path = None
image_callback=None,
if self.model_name in [
"RealESRGAN_x4plus.pth",
"ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
]:
# x4 RRDBNet model
rrdbnet_model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4,
)
netscale = 4
elif self.model_name in ["RealESRGAN_x4plus_anime_6B.pth"]:
# x4 RRDBNet model, 6 blocks
rrdbnet_model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=6, # 6 blocks
num_grow_ch=32,
scale=4,
)
netscale = 4
# TODO: add x2 models handling?
# elif self.model_name in ["RealESRGAN_x2plus"]:
# # x2 RRDBNet model
# model = RRDBNet(
# num_in_ch=3,
# num_out_ch=3,
# num_feat=64,
# num_block=23,
# num_grow_ch=32,
# scale=2,
# )
# model_path = Path()
# netscale = 2
else:
msg = f"Invalid RealESRGAN model: {self.model_name}"
context.services.logger.error(msg)
raise ValueError(msg)
esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
upsampler = RealESRGANer(
scale=netscale,
model_path=str(models_path / esrgan_model_path),
model=rrdbnet_model,
half=False,
) )
# Results are image and seed, unwrap for now # prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
# TODO: can this return multiple results? cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
# We can pass an `outscale` value here, but it just resizes the image by that factor after
# upscaling, so it's kinda pointless for our purposes. If you want something other than 4x
# upscaling, you'll need to add a resize node after this one.
upscaled_image, img_mode = upsampler.enhance(cv_image)
# back to PIL
pil_image = Image.fromarray(
cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)
).convert("RGBA")
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=results[0][0], image=pil_image,
image_origin=ResourceOrigin.INTERNAL, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,

View File

@ -271,13 +271,13 @@ class InvokeAISettings(BaseSettings):
@classmethod @classmethod
def _excluded(self)->List[str]: def _excluded(self)->List[str]:
# combination of deprecated parameters and internal ones that shouldn't be exposed # internal fields that shouldn't be exposed as command line options
return ['type','initconf'] return ['type','initconf']
@classmethod @classmethod
def _excluded_from_yaml(self)->List[str]: def _excluded_from_yaml(self)->List[str]:
# combination of deprecated parameters and internal ones that shouldn't be exposed # combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model'] return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model', 'restore']
class Config: class Config:
env_file_encoding = 'utf-8' env_file_encoding = 'utf-8'
@ -366,7 +366,7 @@ setting environment variables INVOKEAI_<setting>.
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features') log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features') nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features') patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
restore : bool = Field(default=True, description="Enable/disable face restoration code", category='Features') restore : bool = Field(default=True, description="Enable/disable face restoration code (DEPRECATED)", category='DEPRECATED')
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance') always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance') free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')

View File

@ -105,8 +105,6 @@ class EventServiceBase:
def emit_model_load_started ( def emit_model_load_started (
self, self,
graph_execution_state_id: str, graph_execution_state_id: str,
node: dict,
source_node_id: str,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
@ -117,8 +115,6 @@ class EventServiceBase:
event_name="model_load_started", event_name="model_load_started",
payload=dict( payload=dict(
graph_execution_state_id=graph_execution_state_id, graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,
@ -129,8 +125,6 @@ class EventServiceBase:
def emit_model_load_completed( def emit_model_load_completed(
self, self,
graph_execution_state_id: str, graph_execution_state_id: str,
node: dict,
source_node_id: str,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
@ -142,12 +136,12 @@ class EventServiceBase:
event_name="model_load_completed", event_name="model_load_completed",
payload=dict( payload=dict(
graph_execution_state_id=graph_execution_state_id, graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,
submodel=submodel, submodel=submodel,
model_info=model_info, hash=model_info.hash,
location=model_info.location,
precision=str(model_info.precision),
), ),
) )

View File

@ -10,10 +10,9 @@ if TYPE_CHECKING:
from invokeai.app.services.model_manager_service import ModelManagerServiceBase from invokeai.app.services.model_manager_service import ModelManagerServiceBase
from invokeai.app.services.events import EventServiceBase from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.latent_storage import LatentsStorageBase from invokeai.app.services.latent_storage import LatentsStorageBase
from invokeai.app.services.restoration_services import RestorationServices
from invokeai.app.services.invocation_queue import InvocationQueueABC from invokeai.app.services.invocation_queue import InvocationQueueABC
from invokeai.app.services.item_storage import ItemStorageABC from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.config import InvokeAISettings from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
from invokeai.app.services.invoker import InvocationProcessorABC from invokeai.app.services.invoker import InvocationProcessorABC
@ -24,7 +23,7 @@ class InvocationServices:
# TODO: Just forward-declared everything due to circular dependencies. Fix structure. # TODO: Just forward-declared everything due to circular dependencies. Fix structure.
board_images: "BoardImagesServiceABC" board_images: "BoardImagesServiceABC"
boards: "BoardServiceABC" boards: "BoardServiceABC"
configuration: "InvokeAISettings" configuration: "InvokeAIAppConfig"
events: "EventServiceBase" events: "EventServiceBase"
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"] graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
graph_library: "ItemStorageABC"["LibraryGraph"] graph_library: "ItemStorageABC"["LibraryGraph"]
@ -34,13 +33,12 @@ class InvocationServices:
model_manager: "ModelManagerServiceBase" model_manager: "ModelManagerServiceBase"
processor: "InvocationProcessorABC" processor: "InvocationProcessorABC"
queue: "InvocationQueueABC" queue: "InvocationQueueABC"
restoration: "RestorationServices"
def __init__( def __init__(
self, self,
board_images: "BoardImagesServiceABC", board_images: "BoardImagesServiceABC",
boards: "BoardServiceABC", boards: "BoardServiceABC",
configuration: "InvokeAISettings", configuration: "InvokeAIAppConfig",
events: "EventServiceBase", events: "EventServiceBase",
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"], graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
graph_library: "ItemStorageABC"["LibraryGraph"], graph_library: "ItemStorageABC"["LibraryGraph"],
@ -50,7 +48,6 @@ class InvocationServices:
model_manager: "ModelManagerServiceBase", model_manager: "ModelManagerServiceBase",
processor: "InvocationProcessorABC", processor: "InvocationProcessorABC",
queue: "InvocationQueueABC", queue: "InvocationQueueABC",
restoration: "RestorationServices",
): ):
self.board_images = board_images self.board_images = board_images
self.boards = boards self.boards = boards
@ -65,4 +62,3 @@ class InvocationServices:
self.model_manager = model_manager self.model_manager = model_manager
self.processor = processor self.processor = processor
self.queue = queue self.queue = queue
self.restoration = restoration

View File

@ -18,6 +18,7 @@ from invokeai.backend.model_management import (
SchedulerPredictionType, SchedulerPredictionType,
ModelMerger, ModelMerger,
MergeInterpolationMethod, MergeInterpolationMethod,
ModelNotFoundException,
) )
from invokeai.backend.model_management.model_search import FindModels from invokeai.backend.model_management.model_search import FindModels
@ -145,7 +146,7 @@ class ModelManagerServiceBase(ABC):
) -> AddModelResult: ) -> AddModelResult:
""" """
Update the named model with a dictionary of attributes. Will fail with a Update the named model with a dictionary of attributes. Will fail with a
KeyErrorException if the name does not already exist. ModelNotFoundException if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or with an assertion error if provided attributes are incorrect or
@ -338,7 +339,6 @@ class ModelManagerService(ModelManagerServiceBase):
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
submodel: Optional[SubModelType] = None, submodel: Optional[SubModelType] = None,
node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None, context: Optional[InvocationContext] = None,
) -> ModelInfo: ) -> ModelInfo:
""" """
@ -346,11 +346,9 @@ class ModelManagerService(ModelManagerServiceBase):
part (such as the vae) of a diffusers mode. part (such as the vae) of a diffusers mode.
""" """
# if we are called from within a node, then we get to emit # we can emit model loading events if we are executing with access to the invocation context
# load start and complete events if context:
if node and context:
self._emit_load_event( self._emit_load_event(
node=node,
context=context, context=context,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
@ -365,9 +363,8 @@ class ModelManagerService(ModelManagerServiceBase):
submodel, submodel,
) )
if node and context: if context:
self._emit_load_event( self._emit_load_event(
node=node,
context=context, context=context,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
@ -451,14 +448,14 @@ class ModelManagerService(ModelManagerServiceBase):
) -> AddModelResult: ) -> AddModelResult:
""" """
Update the named model with a dictionary of attributes. Will fail with a Update the named model with a dictionary of attributes. Will fail with a
KeyError exception if the name does not already exist. ModelNotFoundException exception if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk. the model name is missing. Call commit() to write changes to disk.
""" """
self.logger.debug(f'update model {model_name}') self.logger.debug(f'update model {model_name}')
if not self.model_exists(model_name, base_model, model_type): if not self.model_exists(model_name, base_model, model_type):
raise KeyError(f"Unknown model {model_name}") raise ModelNotFoundException(f"Unknown model {model_name}")
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True) return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
def del_model( def del_model(
@ -509,23 +506,19 @@ class ModelManagerService(ModelManagerServiceBase):
def _emit_load_event( def _emit_load_event(
self, self,
node,
context, context,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
submodel: SubModelType, submodel: Optional[SubModelType] = None,
model_info: Optional[ModelInfo] = None, model_info: Optional[ModelInfo] = None,
): ):
if context.services.queue.is_canceled(context.graph_execution_state_id): if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException() raise CanceledException()
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[node.id]
if model_info: if model_info:
context.services.events.emit_model_load_completed( context.services.events.emit_model_load_completed(
graph_execution_state_id=context.graph_execution_state_id, graph_execution_state_id=context.graph_execution_state_id,
node=node.dict(),
source_node_id=source_node_id,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,
@ -535,8 +528,6 @@ class ModelManagerService(ModelManagerServiceBase):
else: else:
context.services.events.emit_model_load_started( context.services.events.emit_model_load_started(
graph_execution_state_id=context.graph_execution_state_id, graph_execution_state_id=context.graph_execution_state_id,
node=node.dict(),
source_node_id=source_node_id,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,

View File

@ -1,113 +0,0 @@
import sys
import traceback
import torch
from typing import types
from ...backend.restoration import Restoration
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
# This should be a real base class for postprocessing functions,
# but right now we just instantiate the existing gfpgan, esrgan
# and codeformer functions.
class RestorationServices:
'''Face restoration and upscaling'''
def __init__(self,args,logger:types.ModuleType):
try:
gfpgan, codeformer, esrgan = None, None, None
if args.restore or args.esrgan:
restoration = Restoration()
# TODO: redo for new model structure
if False and args.restore:
gfpgan, codeformer = restoration.load_face_restore_models(
args.gfpgan_model_path
)
else:
logger.info("Face restoration disabled")
if False and args.esrgan:
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
else:
logger.info("Upscaling disabled")
else:
logger.info("Face restoration and upscaling disabled")
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
self.device = torch.device(choose_torch_device())
self.gfpgan = gfpgan
self.codeformer = codeformer
self.esrgan = esrgan
self.logger = logger
self.logger.info('Face restoration initialized')
# note that this one method does gfpgan and codepath reconstruction, as well as
# esrgan upscaling
# TO DO: refactor into separate methods
def upscale_and_reconstruct(
self,
image_list,
facetool="gfpgan",
upscale=None,
upscale_denoise_str=0.75,
strength=0.0,
codeformer_fidelity=0.75,
save_original=False,
image_callback=None,
prefix=None,
):
results = []
for r in image_list:
image, seed = r
try:
if strength > 0:
if self.gfpgan is not None or self.codeformer is not None:
if facetool == "gfpgan":
if self.gfpgan is None:
self.logger.info(
"GFPGAN not found. Face restoration is disabled."
)
else:
image = self.gfpgan.process(image, strength, seed)
if facetool == "codeformer":
if self.codeformer is None:
self.logger.info(
"CodeFormer not found. Face restoration is disabled."
)
else:
cf_device = (
CPU_DEVICE if self.device == MPS_DEVICE else self.device
)
image = self.codeformer.process(
image=image,
strength=strength,
device=cf_device,
seed=seed,
fidelity=codeformer_fidelity,
)
else:
self.logger.info("Face Restoration is disabled.")
if upscale is not None:
if self.esrgan is not None:
if len(upscale) < 2:
upscale.append(0.75)
image = self.esrgan.process(
image,
upscale[1],
seed,
int(upscale[0]),
denoise_str=upscale_denoise_str,
)
else:
self.logger.info("ESRGAN is disabled. Image not upscaled.")
except Exception as e:
self.logger.info(
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
)
if image_callback is not None:
image_callback(image, seed, upscaled=True, use_prefix=prefix)
else:
r[0] = image
results.append([image, seed])
return results

View File

@ -30,8 +30,6 @@ from huggingface_hub import login as hf_hub_login
from omegaconf import OmegaConf from omegaconf import OmegaConf
from tqdm import tqdm from tqdm import tqdm
from transformers import ( from transformers import (
AutoProcessor,
CLIPSegForImageSegmentation,
CLIPTextModel, CLIPTextModel,
CLIPTokenizer, CLIPTokenizer,
AutoFeatureExtractor, AutoFeatureExtractor,
@ -45,7 +43,6 @@ from invokeai.app.services.config import (
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
from invokeai.frontend.install.widgets import ( from invokeai.frontend.install.widgets import (
SingleSelectColumns,
CenteredButtonPress, CenteredButtonPress,
IntTitleSlider, IntTitleSlider,
set_min_terminal_size, set_min_terminal_size,
@ -72,7 +69,6 @@ transformers.logging.set_verbosity_error()
config = InvokeAIAppConfig.get_config() config = InvokeAIAppConfig.get_config()
Model_dir = "models" Model_dir = "models"
Weights_dir = "ldm/stable-diffusion-v1/"
Default_config_file = config.model_conf_path Default_config_file = config.model_conf_path
SD_Configs = config.legacy_conf_path SD_Configs = config.legacy_conf_path
@ -226,64 +222,30 @@ def download_conversion_models():
# --------------------------------------------- # ---------------------------------------------
def download_realesrgan(): def download_realesrgan():
logger.info("Installing models from RealESRGAN...") logger.info("Installing RealESRGAN models...")
model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth" URLs = [
wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth" dict(
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
model_dest = config.root_path / "models/core/upscaling/realesrgan/realesr-general-x4v3.pth" dest = "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
wdn_model_dest = config.root_path / "models/core/upscaling/realesrgan/realesr-general-wdn-x4v3.pth" description = "RealESRGAN_x4plus.pth",
),
download_with_progress_bar(model_url, str(model_dest), "RealESRGAN") dict(
download_with_progress_bar(wdn_model_url, str(wdn_model_dest), "RealESRGANwdn") url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
dest = "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
description = "RealESRGAN_x4plus_anime_6B.pth",
def download_gfpgan(): ),
logger.info("Installing GFPGAN models...") dict(
for model in ( url= "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
[ dest= "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth", description = "ESRGAN_SRx4_DF2KOST_official.pth",
"./models/core/face_restoration/gfpgan/GFPGANv1.4.pth", ),
], ]
[ for model in URLs:
"https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth", download_with_progress_bar(model['url'], config.models_path / model['dest'], model['description'])
"./models/core/face_restoration/gfpgan/weights/detection_Resnet50_Final.pth",
],
[
"https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth",
"./models/core/face_restoration/gfpgan/weights/parsing_parsenet.pth",
],
):
model_url, model_dest = model[0], config.root_path / model[1]
download_with_progress_bar(model_url, str(model_dest), "GFPGAN weights")
# --------------------------------------------- # ---------------------------------------------
def download_codeformer():
logger.info("Installing CodeFormer model file...")
model_url = (
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
)
model_dest = config.root_path / "models/core/face_restoration/codeformer/codeformer.pth"
download_with_progress_bar(model_url, str(model_dest), "CodeFormer")
# ---------------------------------------------
def download_clipseg():
logger.info("Installing clipseg model for text-based masking...")
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
try:
hf_download_from_pretrained(AutoProcessor, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
hf_download_from_pretrained(CLIPSegForImageSegmentation, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
except Exception:
logger.info("Error installing clipseg model:")
logger.info(traceback.format_exc())
def download_support_models(): def download_support_models():
download_realesrgan() download_realesrgan()
download_gfpgan()
download_codeformer()
download_clipseg()
download_conversion_models() download_conversion_models()
# ------------------------------------- # -------------------------------------
@ -666,7 +628,7 @@ def run_console_ui(
# The third argument is needed in the Windows 11 environment to # The third argument is needed in the Windows 11 environment to
# launch a console window running this program. # launch a console window running this program.
set_min_terminal_size(MIN_COLS, MIN_LINES,'invokeai-configure') set_min_terminal_size(MIN_COLS, MIN_LINES)
# the install-models application spawns a subprocess to install # the install-models application spawns a subprocess to install
# models, and will crash unless this is set before running. # models, and will crash unless this is set before running.
@ -743,7 +705,7 @@ def migrate_if_needed(opt: Namespace, root: Path)->bool:
old_init_file = root / 'invokeai.init' old_init_file = root / 'invokeai.init'
new_init_file = root / 'invokeai.yaml' new_init_file = root / 'invokeai.yaml'
old_hub = root / 'models/hub' old_hub = root / 'models/hub'
migration_needed = old_init_file.exists() and not new_init_file.exists() or old_hub.exists() migration_needed = (old_init_file.exists() and not new_init_file.exists()) and old_hub.exists()
if migration_needed: if migration_needed:
if opt.yes_to_all or \ if opt.yes_to_all or \
@ -858,9 +820,9 @@ def main():
download_support_models() download_support_models()
if opt.skip_sd_weights: if opt.skip_sd_weights:
logger.info("\n** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **") logger.warning("SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST")
elif models_to_download: elif models_to_download:
logger.info("\n** DOWNLOADING DIFFUSION WEIGHTS **") logger.info("DOWNLOADING DIFFUSION WEIGHTS")
process_and_execute(opt, models_to_download) process_and_execute(opt, models_to_download)
postscript(errors=errors) postscript(errors=errors)

View File

@ -117,6 +117,7 @@ class ModelInstall(object):
# supplement with entries in models.yaml # supplement with entries in models.yaml
installed_models = self.mgr.list_models() installed_models = self.mgr.list_models()
for md in installed_models: for md in installed_models:
base = md['base_model'] base = md['base_model']
model_type = md['model_type'] model_type = md['model_type']
@ -134,6 +135,12 @@ class ModelInstall(object):
) )
return {x : model_dict[x] for x in sorted(model_dict.keys(),key=lambda y: model_dict[y].name.lower())} return {x : model_dict[x] for x in sorted(model_dict.keys(),key=lambda y: model_dict[y].name.lower())}
def list_models(self, model_type):
installed = self.mgr.list_models(model_type=model_type)
print(f'Installed models of type `{model_type}`:')
for i in installed:
print(f"{i['model_name']}\t{i['base_model']}\t{i['path']}")
def starter_models(self)->Set[str]: def starter_models(self)->Set[str]:
models = set() models = set()
for key, value in self.datasets.items(): for key, value in self.datasets.items():
@ -205,7 +212,7 @@ class ModelInstall(object):
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'} {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
] ]
): ):
models_installed.update(self._install_path(path)) models_installed.update({str(model_path_id_or_url): self._install_path(path)})
# recursive scan # recursive scan
elif path.is_dir(): elif path.is_dir():

View File

@ -3,6 +3,6 @@ Initialization file for invokeai.backend.model_management
""" """
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
from .model_cache import ModelCache from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType from .models import BaseModelType, ModelType, SubModelType, ModelVariantType, ModelNotFoundException
from .model_merge import ModelMerger, MergeInterpolationMethod from .model_merge import ModelMerger, MergeInterpolationMethod

View File

@ -552,7 +552,7 @@ class ModelManager(object):
model_config = self.models.get(model_key) model_config = self.models.get(model_key)
if not model_config: if not model_config:
self.logger.error(f'Unknown model {model_name}') self.logger.error(f'Unknown model {model_name}')
raise KeyError(f'Unknown model {model_name}') raise ModelNotFoundException(f'Unknown model {model_name}')
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key) cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
if base_model is not None and cur_base_model != base_model: if base_model is not None and cur_base_model != base_model:
@ -568,6 +568,9 @@ class ModelManager(object):
model_type=cur_model_type, model_type=cur_model_type,
) )
# expose paths as absolute to help web UI
if path := model_dict.get('path'):
model_dict['path'] = str(self.app_config.root_path / path)
models.append(model_dict) models.append(model_dict)
return models return models
@ -596,7 +599,7 @@ class ModelManager(object):
model_cfg = self.models.pop(model_key, None) model_cfg = self.models.pop(model_key, None)
if model_cfg is None: if model_cfg is None:
raise KeyError(f"Unknown model {model_key}") raise ModelNotFoundException(f"Unknown model {model_key}")
# note: it not garantie to release memory(model can has other references) # note: it not garantie to release memory(model can has other references)
cache_ids = self.cache_keys.pop(model_key, []) cache_ids = self.cache_keys.pop(model_key, [])
@ -635,6 +638,10 @@ class ModelManager(object):
The returned dict has the same format as the dict returned by The returned dict has the same format as the dict returned by
model_info(). model_info().
""" """
# relativize paths as they go in - this makes it easier to move the root directory around
if path := model_attributes.get('path'):
if Path(path).is_relative_to(self.app_config.root_path):
model_attributes['path'] = str(Path(path).relative_to(self.app_config.root_path))
model_class = MODEL_CLASSES[base_model][model_type] model_class = MODEL_CLASSES[base_model][model_type]
model_config = model_class.create_config(**model_attributes) model_config = model_class.create_config(**model_attributes)
@ -689,7 +696,7 @@ class ModelManager(object):
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
model_cfg = self.models.get(model_key, None) model_cfg = self.models.get(model_key, None)
if not model_cfg: if not model_cfg:
raise KeyError(f"Unknown model: {model_key}") raise ModelNotFoundException(f"Unknown model: {model_key}")
old_path = self.app_config.root_path / model_cfg.path old_path = self.app_config.root_path / model_cfg.path
new_name = new_name or model_name new_name = new_name or model_name
@ -700,7 +707,7 @@ class ModelManager(object):
# if this is a model file/directory that we manage ourselves, we need to move it # if this is a model file/directory that we manage ourselves, we need to move it
if old_path.is_relative_to(self.app_config.models_path): if old_path.is_relative_to(self.app_config.models_path):
new_path = self.app_config.root_path / 'models' / new_base.value / model_type.value / new_name new_path = self.app_config.root_path / 'models' / BaseModelType(new_base).value / ModelType(model_type).value / new_name
move(old_path, new_path) move(old_path, new_path)
model_cfg.path = str(new_path.relative_to(self.app_config.root_path)) model_cfg.path = str(new_path.relative_to(self.app_config.root_path))
@ -908,7 +915,6 @@ class ModelManager(object):
from invokeai.backend.install.model_install_backend import ModelInstall from invokeai.backend.install.model_install_backend import ModelInstall
from invokeai.frontend.install.model_install import ask_user_for_prediction_type from invokeai.frontend.install.model_install import ask_user_for_prediction_type
class ScanAndImport(ModelSearch): class ScanAndImport(ModelSearch):
def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall): def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall):
super().__init__(directories, logger) super().__init__(directories, logger)
@ -965,7 +971,7 @@ class ModelManager(object):
that model. that model.
May return the following exceptions: May return the following exceptions:
- KeyError - one or more of the items to import is not a valid path, repo_id or URL - ModelNotFoundException - one or more of the items to import is not a valid path, repo_id or URL
- ValueError - a corresponding model already exists - ValueError - a corresponding model already exists
''' '''
# avoid circular import here # avoid circular import here

View File

@ -12,6 +12,7 @@ from picklescan.scanner import scan_file_path
from .models import ( from .models import (
BaseModelType, ModelType, ModelVariantType, BaseModelType, ModelType, ModelVariantType,
SchedulerPredictionType, SilenceWarnings, SchedulerPredictionType, SilenceWarnings,
InvalidModelException
) )
from .models.base import read_checkpoint_meta from .models.base import read_checkpoint_meta
@ -61,7 +62,7 @@ class ModelProbe(object):
elif isinstance(model,(dict,ModelMixin,ConfigMixin)): elif isinstance(model,(dict,ModelMixin,ConfigMixin)):
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper) return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
else: else:
raise ValueError("model parameter {model} is neither a Path, nor a model") raise InvalidModelException("model parameter {model} is neither a Path, nor a model")
@classmethod @classmethod
def probe(cls, def probe(cls,
@ -141,7 +142,7 @@ class ModelProbe(object):
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()): if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
return ModelType.TextualInversion return ModelType.TextualInversion
raise ValueError(f"Unable to determine model type for {model_path}") raise InvalidModelException(f"Unable to determine model type for {model_path}")
@classmethod @classmethod
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType: def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
@ -171,7 +172,7 @@ class ModelProbe(object):
return type return type
# give up # give up
raise ValueError(f"Unable to determine model type for {folder_path}") raise InvalidModelException(f"Unable to determine model type for {folder_path}")
@classmethod @classmethod
def _scan_and_load_checkpoint(cls,model_path: Path)->dict: def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
@ -240,7 +241,7 @@ class CheckpointProbeBase(ProbeBase):
elif in_channels == 4: elif in_channels == 4:
return ModelVariantType.Normal return ModelVariantType.Normal
else: else:
raise ValueError(f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}") raise InvalidModelException(f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}")
class PipelineCheckpointProbe(CheckpointProbeBase): class PipelineCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType: def get_base_type(self)->BaseModelType:
@ -254,7 +255,7 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
# TODO: Verify that this is correct! Need an XL checkpoint file for this. # TODO: Verify that this is correct! Need an XL checkpoint file for this.
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048: if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
return BaseModelType.StableDiffusionXL return BaseModelType.StableDiffusionXL
raise ValueError("Cannot determine base type") raise InvalidModelException("Cannot determine base type")
def get_scheduler_prediction_type(self)->SchedulerPredictionType: def get_scheduler_prediction_type(self)->SchedulerPredictionType:
type = self.get_base_type() type = self.get_base_type()
@ -335,7 +336,7 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusion2 return BaseModelType.StableDiffusion2
elif self.checkpoint_path and self.helper: elif self.checkpoint_path and self.helper:
return self.helper(self.checkpoint_path) return self.helper(self.checkpoint_path)
raise ValueError("Unable to determine base type for {self.checkpoint_path}") raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
######################################################## ########################################################
# classes for probing folders # classes for probing folders
@ -371,7 +372,7 @@ class PipelineFolderProbe(FolderProbeBase):
elif unet_conf['cross_attention_dim'] == 2048: elif unet_conf['cross_attention_dim'] == 2048:
return BaseModelType.StableDiffusionXL return BaseModelType.StableDiffusionXL
else: else:
raise ValueError(f'Unknown base model for {self.folder_path}') raise InvalidModelException(f'Unknown base model for {self.folder_path}')
def get_scheduler_prediction_type(self)->SchedulerPredictionType: def get_scheduler_prediction_type(self)->SchedulerPredictionType:
if self.model: if self.model:
@ -428,7 +429,7 @@ class ControlNetFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType: def get_base_type(self)->BaseModelType:
config_file = self.folder_path / 'config.json' config_file = self.folder_path / 'config.json'
if not config_file.exists(): if not config_file.exists():
raise ValueError(f"Cannot determine base type for {self.folder_path}") raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
with open(config_file,'r') as file: with open(config_file,'r') as file:
config = json.load(file) config = json.load(file)
# no obvious way to distinguish between sd2-base and sd2-768 # no obvious way to distinguish between sd2-base and sd2-768
@ -445,7 +446,7 @@ class LoRAFolderProbe(FolderProbeBase):
model_file = base_file model_file = base_file
break break
if not model_file: if not model_file:
raise ValueError('Unknown LoRA format encountered') raise InvalidModelException('Unknown LoRA format encountered')
return LoRACheckpointProbe(model_file,None).get_base_type() return LoRACheckpointProbe(model_file,None).get_base_type()
############## register probe classes ###### ############## register probe classes ######

View File

@ -68,7 +68,7 @@ class TextualInversionModel(ModelBase):
return None # diffusers-ti return None # diffusers-ti
if os.path.isfile(path): if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]): if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]]):
return None return None
raise InvalidModelException(f"Not a valid model: {path}") raise InvalidModelException(f"Not a valid model: {path}")

View File

@ -16,6 +16,7 @@ from .base import (
calc_model_size_by_data, calc_model_size_by_data,
classproperty, classproperty,
InvalidModelException, InvalidModelException,
ModelNotFoundException,
) )
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available from diffusers.utils import is_safetensors_available

View File

@ -1,4 +0,0 @@
"""
Initialization file for the invokeai.backend.restoration package
"""
from .base import Restoration

View File

@ -1,45 +0,0 @@
import invokeai.backend.util.logging as logger
class Restoration:
def __init__(self) -> None:
pass
def load_face_restore_models(
self, gfpgan_model_path="./models/core/face_restoration/gfpgan/GFPGANv1.4.pth"
):
# Load GFPGAN
gfpgan = self.load_gfpgan(gfpgan_model_path)
if gfpgan.gfpgan_model_exists:
logger.info("GFPGAN Initialized")
else:
logger.info("GFPGAN Disabled")
gfpgan = None
# Load CodeFormer
codeformer = self.load_codeformer()
if codeformer.codeformer_model_exists:
logger.info("CodeFormer Initialized")
else:
logger.info("CodeFormer Disabled")
codeformer = None
return gfpgan, codeformer
# Face Restore Models
def load_gfpgan(self, gfpgan_model_path):
from .gfpgan import GFPGAN
return GFPGAN(gfpgan_model_path)
def load_codeformer(self):
from .codeformer import CodeFormerRestoration
return CodeFormerRestoration()
# Upscale Models
def load_esrgan(self, esrgan_bg_tile=400):
from .realesrgan import ESRGAN
esrgan = ESRGAN(esrgan_bg_tile)
logger.info("ESRGAN Initialized")
return esrgan

View File

@ -1,120 +0,0 @@
import os
import sys
import warnings
import numpy as np
import torch
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
pretrained_model_url = (
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
)
class CodeFormerRestoration:
def __init__(
self, codeformer_dir="./models/core/face_restoration/codeformer", codeformer_model_path="codeformer.pth"
) -> None:
self.globals = InvokeAIAppConfig.get_config()
codeformer_dir = self.globals.root_dir / codeformer_dir
self.model_path = codeformer_dir / codeformer_model_path
self.codeformer_model_exists = self.model_path.exists()
if not self.codeformer_model_exists:
logger.error(f"NOT FOUND: CodeFormer model not found at {self.model_path}")
sys.path.append(os.path.abspath(codeformer_dir))
def process(self, image, strength, device, seed=None, fidelity=0.75):
if seed is not None:
logger.info(f"CodeFormer - Restoring Faces for image seed:{seed}")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
from basicsr.utils import img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from PIL import Image
from torchvision.transforms.functional import normalize
from .codeformer_arch import CodeFormer
cf_class = CodeFormer
cf = cf_class(
dim_embd=512,
codebook_size=1024,
n_head=8,
n_layers=9,
connect_list=["32", "64", "128", "256"],
).to(device)
# note that this file should already be downloaded and cached at
# this point
checkpoint_path = load_file_from_url(
url=pretrained_model_url,
model_dir=os.path.abspath(os.path.dirname(self.model_path)),
progress=True,
)
checkpoint = torch.load(checkpoint_path)["params_ema"]
cf.load_state_dict(checkpoint)
cf.eval()
image = image.convert("RGB")
# Codeformer expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]
face_helper = FaceRestoreHelper(
upscale_factor=1,
use_parse=True,
device=device,
model_rootpath = self.globals.model_path / 'core/face_restoration/gfpgan/weights'
)
face_helper.clean_all()
face_helper.read_image(bgr_image_array)
face_helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5)
face_helper.align_warp_face()
for idx, cropped_face in enumerate(face_helper.cropped_faces):
cropped_face_t = img2tensor(
cropped_face / 255.0, bgr2rgb=True, float32=True
)
normalize(
cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True
)
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
try:
with torch.no_grad():
output = cf(cropped_face_t, w=fidelity, adain=True)[0]
restored_face = tensor2img(
output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)
)
del output
torch.cuda.empty_cache()
except RuntimeError as error:
logger.error(f"Failed inference for CodeFormer: {error}.")
restored_face = cropped_face
restored_face = restored_face.astype("uint8")
face_helper.add_restored_face(restored_face)
face_helper.get_inverse_affine(None)
restored_img = face_helper.paste_faces_to_input_image()
# Flip the channels back to RGB
res = Image.fromarray(restored_img[..., ::-1])
if strength < 1.0:
# Resize the image to the new image if the sizes have changed
if restored_img.size != image.size:
image = image.resize(res.size)
res = Image.blend(image, res, strength)
cf = None
return res

View File

@ -1,325 +0,0 @@
import math
from typing import List, Optional
import numpy as np
import torch
import torch.nn.functional as F
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
from torch import Tensor, nn
from .vqgan_arch import *
def calc_mean_std(feat, eps=1e-5):
"""Calculate mean and std for adaptive_instance_normalization.
Args:
feat (Tensor): 4D tensor.
eps (float): A small value added to the variance to avoid
divide-by-zero. Default: 1e-5.
"""
size = feat.size()
assert len(size) == 4, "The input feature should be 4D tensor."
b, c = size[:2]
feat_var = feat.view(b, c, -1).var(dim=2) + eps
feat_std = feat_var.sqrt().view(b, c, 1, 1)
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
return feat_mean, feat_std
def adaptive_instance_normalization(content_feat, style_feat):
"""Adaptive instance normalization.
Adjust the reference features to have the similar color and illuminations
as those in the degradate features.
Args:
content_feat (Tensor): The reference feature.
style_feat (Tensor): The degradate features.
"""
size = content_feat.size()
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(
size
)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, x, mask=None):
if mask is None:
mask = torch.zeros(
(x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
)
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
class TransformerSALayer(nn.Module):
def __init__(
self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"
):
super().__init__()
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
# Implementation of Feedforward model - MLP
self.linear1 = nn.Linear(embed_dim, dim_mlp)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_mlp, embed_dim)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward(
self,
tgt,
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
# self attention
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(
q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
)[0]
tgt = tgt + self.dropout1(tgt2)
# ffn
tgt2 = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout2(tgt2)
return tgt
class Fuse_sft_block(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.encode_enc = ResBlock(2 * in_ch, out_ch)
self.scale = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
)
self.shift = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
)
def forward(self, enc_feat, dec_feat, w=1):
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
scale = self.scale(enc_feat)
shift = self.shift(enc_feat)
residual = w * (dec_feat * scale + shift)
out = dec_feat + residual
return out
@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
def __init__(
self,
dim_embd=512,
n_head=8,
n_layers=9,
codebook_size=1024,
latent_size=256,
connect_list=["32", "64", "128", "256"],
fix_modules=["quantize", "generator"],
):
super(CodeFormer, self).__init__(
512, 64, [1, 2, 2, 4, 4, 8], "nearest", 2, [16], codebook_size
)
if fix_modules is not None:
for module in fix_modules:
for param in getattr(self, module).parameters():
param.requires_grad = False
self.connect_list = connect_list
self.n_layers = n_layers
self.dim_embd = dim_embd
self.dim_mlp = dim_embd * 2
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
self.feat_emb = nn.Linear(256, self.dim_embd)
# transformer
self.ft_layers = nn.Sequential(
*[
TransformerSALayer(
embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0
)
for _ in range(self.n_layers)
]
)
# logits_predict head
self.idx_pred_layer = nn.Sequential(
nn.LayerNorm(dim_embd), nn.Linear(dim_embd, codebook_size, bias=False)
)
self.channels = {
"16": 512,
"32": 256,
"64": 256,
"128": 128,
"256": 128,
"512": 64,
}
# after second residual block for > 16, before attn layer for ==16
self.fuse_encoder_block = {
"512": 2,
"256": 5,
"128": 8,
"64": 11,
"32": 14,
"16": 18,
}
# after first residual block for > 16, before attn layer for ==16
self.fuse_generator_block = {
"16": 6,
"32": 9,
"64": 12,
"128": 15,
"256": 18,
"512": 21,
}
# fuse_convs_dict
self.fuse_convs_dict = nn.ModuleDict()
for f_size in self.connect_list:
in_ch = self.channels[f_size]
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
# ################### Encoder #####################
enc_feat_dict = {}
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.encoder.blocks):
x = block(x)
if i in out_list:
enc_feat_dict[str(x.shape[-1])] = x.clone()
lq_feat = x
# ################# Transformer ###################
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1)
# BCHW -> BC(HW) -> (HW)BC
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
query_emb = feat_emb
# Transformer encoder
for layer in self.ft_layers:
query_emb = layer(query_emb, query_pos=pos_emb)
# output logits
logits = self.idx_pred_layer(query_emb) # (hw)bn
logits = logits.permute(1, 0, 2) # (hw)bn -> b(hw)n
if code_only: # for training stage II
# logits doesn't need softmax before cross_entropy loss
return logits, lq_feat
# ################# Quantization ###################
# if self.training:
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
# # b(hw)c -> bc(hw) -> bchw
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
# ------------
soft_one_hot = F.softmax(logits, dim=2)
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
quant_feat = self.quantize.get_codebook_feat(
top_idx, shape=[x.shape[0], 16, 16, 256]
)
# preserve gradients
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
if detach_16:
quant_feat = quant_feat.detach() # for training stage III
if adain:
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
# ################## Generator ####################
x = quant_feat
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.generator.blocks):
x = block(x)
if i in fuse_list: # fuse after i-th block
f_size = str(x.shape[-1])
if w > 0:
x = self.fuse_convs_dict[f_size](
enc_feat_dict[f_size].detach(), x, w
)
out = x
# logits doesn't need softmax before cross_entropy loss
return out, logits, lq_feat

View File

@ -1,84 +0,0 @@
import os
import sys
import warnings
import numpy as np
import torch
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
class GFPGAN:
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
self.globals = InvokeAIAppConfig.get_config()
if not os.path.isabs(gfpgan_model_path):
gfpgan_model_path = self.globals.root_dir / gfpgan_model_path
self.model_path = gfpgan_model_path
self.gfpgan_model_exists = os.path.isfile(self.model_path)
if not self.gfpgan_model_exists:
logger.error(f"NOT FOUND: GFPGAN model not found at {self.model_path}")
return None
def model_exists(self):
return os.path.isfile(self.model_path)
def process(self, image, strength: float, seed: str = None):
if seed is not None:
logger.info(f"GFPGAN - Restoring Faces for image seed:{seed}")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
cwd = os.getcwd()
os.chdir(self.globals.root_dir / 'models')
try:
from gfpgan import GFPGANer
self.gfpgan = GFPGANer(
model_path=self.model_path,
upscale=1,
arch="clean",
channel_multiplier=2,
bg_upsampler=None,
)
except Exception:
import traceback
logger.error("Error loading GFPGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
os.chdir(cwd)
if self.gfpgan is None:
logger.warning("WARNING: GFPGAN not initialized.")
logger.warning(
f"Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
)
image = image.convert("RGB")
# GFPGAN expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]
_, _, restored_img = self.gfpgan.enhance(
bgr_image_array,
has_aligned=False,
only_center_face=False,
paste_back=True,
)
# Flip the channels back to RGB
res = Image.fromarray(restored_img[..., ::-1])
if strength < 1.0:
# Resize the image to the new image if the sizes have changed
if restored_img.size != image.size:
image = image.resize(res.size)
res = Image.blend(image, res, strength)
if torch.cuda.is_available():
torch.cuda.empty_cache()
self.gfpgan = None
return res

View File

@ -1,118 +0,0 @@
import math
from PIL import Image
import invokeai.backend.util.logging as logger
class Outcrop(object):
def __init__(
self,
image,
generate, # current generate object
):
self.image = image
self.generate = generate
def process(
self,
extents: dict,
opt, # current options
orig_opt, # ones originally used to generate the image
image_callback=None,
prefix=None,
):
# grow and mask the image
extended_image = self._extend_all(extents)
# switch samplers temporarily
curr_sampler = self.generate.sampler
self.generate.sampler_name = opt.sampler_name
self.generate._set_scheduler()
def wrapped_callback(img, seed, **kwargs):
preferred_seed = (
orig_opt.seed
if orig_opt.seed is not None and orig_opt.seed >= 0
else seed
)
image_callback(img, preferred_seed, use_prefix=prefix, **kwargs)
result = self.generate.prompt2image(
opt.prompt,
seed=opt.seed or orig_opt.seed,
sampler=self.generate.sampler,
steps=opt.steps,
cfg_scale=opt.cfg_scale,
ddim_eta=self.generate.ddim_eta,
width=extended_image.width,
height=extended_image.height,
init_img=extended_image,
strength=0.90,
image_callback=wrapped_callback if image_callback else None,
seam_size=opt.seam_size or 96,
seam_blur=opt.seam_blur or 16,
seam_strength=opt.seam_strength or 0.7,
seam_steps=20,
tile_size=32,
color_match=True,
force_outpaint=True, # this just stops the warning about erased regions
)
# swap sampler back
self.generate.sampler = curr_sampler
return result
def _extend_all(
self,
extents: dict,
) -> Image:
"""
Extend the image in direction ('top','bottom','left','right') by
the indicated value. The image canvas is extended, and the empty
rectangular section will be filled with a blurred copy of the
adjacent image.
"""
image = self.image
for direction in extents:
assert direction in [
"top",
"left",
"bottom",
"right",
], 'Direction must be one of "top", "left", "bottom", "right"'
pixels = extents[direction]
# round pixels up to the nearest 64
pixels = math.ceil(pixels / 64) * 64
logger.info(f"extending image {direction}ward by {pixels} pixels")
image = self._rotate(image, direction)
image = self._extend(image, pixels)
image = self._rotate(image, direction, reverse=True)
return image
def _rotate(self, image: Image, direction: str, reverse=False) -> Image:
"""
Rotates image so that the area to extend is always at the top top.
Simplifies logic later. The reverse argument, if true, will undo the
previous transpose.
"""
transposes = {
"right": ["ROTATE_90", "ROTATE_270"],
"bottom": ["ROTATE_180", "ROTATE_180"],
"left": ["ROTATE_270", "ROTATE_90"],
}
if direction not in transposes:
return image
transpose = transposes[direction][1 if reverse else 0]
return image.transpose(Image.Transpose.__dict__[transpose])
def _extend(self, image: Image, pixels: int) -> Image:
extended_img = Image.new("RGBA", (image.width, image.height + pixels))
extended_img.paste((0, 0, 0), [0, 0, image.width, image.height + pixels])
extended_img.paste(image, box=(0, pixels))
# now make the top part transparent to use as a mask
alpha = extended_img.getchannel("A")
alpha.paste(0, (0, 0, extended_img.width, pixels))
extended_img.putalpha(alpha)
return extended_img

View File

@ -1,102 +0,0 @@
import math
import warnings
from PIL import Image, ImageFilter
class Outpaint(object):
def __init__(self, image, generate):
self.image = image
self.generate = generate
def process(self, opt, old_opt, image_callback=None, prefix=None):
image = self._create_outpaint_image(self.image, opt.out_direction)
seed = old_opt.seed
prompt = old_opt.prompt
def wrapped_callback(img, seed, **kwargs):
image_callback(img, seed, use_prefix=prefix, **kwargs)
return self.generate.prompt2image(
prompt,
seed=seed,
sampler=self.generate.sampler,
steps=opt.steps,
cfg_scale=opt.cfg_scale,
ddim_eta=self.generate.ddim_eta,
width=opt.width,
height=opt.height,
init_img=image,
strength=0.83,
image_callback=wrapped_callback,
prefix=prefix,
)
def _create_outpaint_image(self, image, direction_args):
assert len(direction_args) in [
1,
2,
], "Direction (-D) must have exactly one or two arguments."
if len(direction_args) == 1:
direction = direction_args[0]
pixels = None
elif len(direction_args) == 2:
direction = direction_args[0]
pixels = int(direction_args[1])
assert direction in [
"top",
"left",
"bottom",
"right",
], 'Direction (-D) must be one of "top", "left", "bottom", "right"'
image = image.convert("RGBA")
# we always extend top, but rotate to extend along the requested side
if direction == "left":
image = image.transpose(Image.Transpose.ROTATE_270)
elif direction == "bottom":
image = image.transpose(Image.Transpose.ROTATE_180)
elif direction == "right":
image = image.transpose(Image.Transpose.ROTATE_90)
pixels = image.height // 2 if pixels is None else int(pixels)
assert (
0 < pixels < image.height
), "Direction (-D) pixels length must be in the range 0 - image.size"
# the top part of the image is taken from the source image mirrored
# coordinates (0,0) are the upper left corner of an image
top = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).convert("RGBA")
top = top.crop((0, top.height - pixels, top.width, top.height))
# setting all alpha of the top part to 0
alpha = top.getchannel("A")
alpha.paste(0, (0, 0, top.width, top.height))
top.putalpha(alpha)
# taking the bottom from the original image
bottom = image.crop((0, 0, image.width, image.height - pixels))
new_img = image.copy()
new_img.paste(top, (0, 0))
new_img.paste(bottom, (0, pixels))
# create a 10% dither in the middle
dither = min(image.height // 10, pixels)
for x in range(0, image.width, 2):
for y in range(pixels - dither, pixels + dither):
(r, g, b, a) = new_img.getpixel((x, y))
new_img.putpixel((x, y), (r, g, b, 0))
# let's rotate back again
if direction == "left":
new_img = new_img.transpose(Image.Transpose.ROTATE_90)
elif direction == "bottom":
new_img = new_img.transpose(Image.Transpose.ROTATE_180)
elif direction == "right":
new_img = new_img.transpose(Image.Transpose.ROTATE_270)
return new_img

View File

@ -1,104 +0,0 @@
import warnings
import numpy as np
import torch
from PIL import Image
from PIL.Image import Image as ImageType
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig.get_config()
class ESRGAN:
def __init__(self, bg_tile_size=400) -> None:
self.bg_tile_size = bg_tile_size
def load_esrgan_bg_upsampler(self, denoise_str):
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
else:
use_half_precision = True
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
model = SRVGGNetCompact(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_conv=32,
upscale=4,
act_type="prelu",
)
model_path = config.models_path / "core/upscaling/realesrgan/realesr-general-x4v3.pth"
wdn_model_path = config.models_path / "core/upscaling/realesrgan/realesr-general-wdn-x4v3.pth"
scale = 4
bg_upsampler = RealESRGANer(
scale=scale,
model_path=[model_path, wdn_model_path],
model=model,
tile=self.bg_tile_size,
dni_weight=[denoise_str, 1 - denoise_str],
tile_pad=10,
pre_pad=0,
half=use_half_precision,
)
return bg_upsampler
def process(
self,
image: ImageType,
strength: float,
seed: str = None,
upsampler_scale: int = 2,
denoise_str: float = 0.75,
):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
try:
upsampler = self.load_esrgan_bg_upsampler(denoise_str)
except Exception:
import sys
import traceback
logger.error("Error loading Real-ESRGAN:")
print(traceback.format_exc(), file=sys.stderr)
if upsampler_scale == 0:
logger.warning("Real-ESRGAN: Invalid scaling option. Image not upscaled.")
return image
if seed is not None:
logger.info(
f"Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
)
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
image = image.convert("RGB")
# REALSRGAN expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]
output, _ = upsampler.enhance(
bgr_image_array,
outscale=upsampler_scale,
alpha_upsampler="realesrgan",
)
# Flip the channels back to RGB
res = Image.fromarray(output[..., ::-1])
if strength < 1.0:
# Resize the image to the new image if the sizes have changed
if output.size != image.size:
image = image.resize(res.size)
res = Image.blend(image, res, strength)
if torch.cuda.is_available():
torch.cuda.empty_cache()
upsampler = None
return res

View File

@ -1,514 +0,0 @@
"""
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
"""
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
def normalize(in_channels):
return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
@torch.jit.script
def swish(x):
return x * torch.sigmoid(x)
# Define VQVAE classes
class VectorQuantizer(nn.Module):
def __init__(self, codebook_size, emb_dim, beta):
super(VectorQuantizer, self).__init__()
self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
self.embedding.weight.data.uniform_(
-1.0 / self.codebook_size, 1.0 / self.codebook_size
)
def forward(self, z):
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.emb_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (
(z_flattened**2).sum(dim=1, keepdim=True)
+ (self.embedding.weight**2).sum(1)
- 2 * torch.matmul(z_flattened, self.embedding.weight.t())
)
mean_distance = torch.mean(d)
# find closest encodings
# min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
min_encoding_scores, min_encoding_indices = torch.topk(
d, 1, dim=1, largest=False
)
# [0-1], higher score, higher confidence
min_encoding_scores = torch.exp(-min_encoding_scores / 10)
min_encodings = torch.zeros(
min_encoding_indices.shape[0], self.codebook_size
).to(z)
min_encodings.scatter_(1, min_encoding_indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
# compute loss for embedding
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
(z_q - z.detach()) ** 2
)
# preserve gradients
z_q = z + (z_q - z).detach()
# perplexity
e_mean = torch.mean(min_encodings, dim=0)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return (
z_q,
loss,
{
"perplexity": perplexity,
"min_encodings": min_encodings,
"min_encoding_indices": min_encoding_indices,
"min_encoding_scores": min_encoding_scores,
"mean_distance": mean_distance,
},
)
def get_codebook_feat(self, indices, shape):
# input indices: batch*token_num -> (batch*token_num)*1
# shape: batch, height, width, channel
indices = indices.view(-1, 1)
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
min_encodings.scatter_(1, indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
if shape is not None: # reshape back to match original input shape
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
return z_q
class GumbelQuantizer(nn.Module):
def __init__(
self,
codebook_size,
emb_dim,
num_hiddens,
straight_through=False,
kl_weight=5e-4,
temp_init=1.0,
):
super().__init__()
self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding
self.straight_through = straight_through
self.temperature = temp_init
self.kl_weight = kl_weight
self.proj = nn.Conv2d(
num_hiddens, codebook_size, 1
) # projects last encoder layer to quantized logits
self.embed = nn.Embedding(codebook_size, emb_dim)
def forward(self, z):
hard = self.straight_through if self.training else True
logits = self.proj(z)
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
# + kl divergence to the prior loss
qy = F.softmax(logits, dim=1)
diff = (
self.kl_weight
* torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
)
min_encoding_indices = soft_one_hot.argmax(dim=1)
return z_q, diff, {"min_encoding_indices": min_encoding_indices}
class Downsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def forward(self, x):
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x):
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class ResBlock(nn.Module):
def __init__(self, in_channels, out_channels=None):
super(ResBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.norm1 = normalize(in_channels)
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
self.norm2 = normalize(out_channels)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels:
self.conv_out = nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x_in):
x = x_in
x = self.norm1(x)
x = swish(x)
x = self.conv1(x)
x = self.norm2(x)
x = swish(x)
x = self.conv2(x)
if self.in_channels != self.out_channels:
x_in = self.conv_out(x_in)
return x + x_in
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1)
k = k.reshape(b, c, h * w)
w_ = torch.bmm(q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = F.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return x + h_
class Encoder(nn.Module):
def __init__(
self,
in_channels,
nf,
emb_dim,
ch_mult,
num_res_blocks,
resolution,
attn_resolutions,
):
super().__init__()
self.nf = nf
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.attn_resolutions = attn_resolutions
curr_res = self.resolution
in_ch_mult = (1,) + tuple(ch_mult)
blocks = []
# initial convultion
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
# residual and downsampling blocks, with attention on smaller res (16x16)
for i in range(self.num_resolutions):
block_in_ch = nf * in_ch_mult[i]
block_out_ch = nf * ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
if i != self.num_resolutions - 1:
blocks.append(Downsample(block_in_ch))
curr_res = curr_res // 2
# non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch))
blocks.append(AttnBlock(block_in_ch))
blocks.append(ResBlock(block_in_ch, block_in_ch))
# normalise and convert to latent size
blocks.append(normalize(block_in_ch))
blocks.append(
nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)
)
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
class Generator(nn.Module):
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
super().__init__()
self.nf = nf
self.ch_mult = ch_mult
self.num_resolutions = len(self.ch_mult)
self.num_res_blocks = res_blocks
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.in_channels = emb_dim
self.out_channels = 3
block_in_ch = self.nf * self.ch_mult[-1]
curr_res = self.resolution // 2 ** (self.num_resolutions - 1)
blocks = []
# initial conv
blocks.append(
nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)
)
# non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch))
blocks.append(AttnBlock(block_in_ch))
blocks.append(ResBlock(block_in_ch, block_in_ch))
for i in reversed(range(self.num_resolutions)):
block_out_ch = self.nf * self.ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in self.attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
if i != 0:
blocks.append(Upsample(block_in_ch))
curr_res = curr_res * 2
blocks.append(normalize(block_in_ch))
blocks.append(
nn.Conv2d(
block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1
)
)
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
def __init__(
self,
img_size,
nf,
ch_mult,
quantizer="nearest",
res_blocks=2,
attn_resolutions=[16],
codebook_size=1024,
emb_dim=256,
beta=0.25,
gumbel_straight_through=False,
gumbel_kl_weight=1e-8,
model_path=None,
):
super().__init__()
logger = get_root_logger()
self.in_channels = 3
self.nf = nf
self.n_blocks = res_blocks
self.codebook_size = codebook_size
self.embed_dim = emb_dim
self.ch_mult = ch_mult
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.quantizer_type = quantizer
self.encoder = Encoder(
self.in_channels,
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions,
)
if self.quantizer_type == "nearest":
self.beta = beta # 0.25
self.quantize = VectorQuantizer(
self.codebook_size, self.embed_dim, self.beta
)
elif self.quantizer_type == "gumbel":
self.gumbel_num_hiddens = emb_dim
self.straight_through = gumbel_straight_through
self.kl_weight = gumbel_kl_weight
self.quantize = GumbelQuantizer(
self.codebook_size,
self.embed_dim,
self.gumbel_num_hiddens,
self.straight_through,
self.kl_weight,
)
self.generator = Generator(
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions,
)
if model_path is not None:
chkpt = torch.load(model_path, map_location="cpu")
if "params_ema" in chkpt:
self.load_state_dict(
torch.load(model_path, map_location="cpu")["params_ema"]
)
logger.info(f"vqgan is loaded from: {model_path} [params_ema]")
elif "params" in chkpt:
self.load_state_dict(
torch.load(model_path, map_location="cpu")["params"]
)
logger.info(f"vqgan is loaded from: {model_path} [params]")
else:
raise ValueError(f"Wrong params!")
def forward(self, x):
x = self.encoder(x)
quant, codebook_loss, quant_stats = self.quantize(x)
x = self.generator(quant)
return x, codebook_loss, quant_stats
# patch based discriminator
@ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module):
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
super().__init__()
layers = [
nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, True),
]
ndf_mult = 1
ndf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
ndf_mult_prev = ndf_mult
ndf_mult = min(2**n, 8)
layers += [
nn.Conv2d(
ndf * ndf_mult_prev,
ndf * ndf_mult,
kernel_size=4,
stride=2,
padding=1,
bias=False,
),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True),
]
ndf_mult_prev = ndf_mult
ndf_mult = min(2**n_layers, 8)
layers += [
nn.Conv2d(
ndf * ndf_mult_prev,
ndf * ndf_mult,
kernel_size=4,
stride=1,
padding=1,
bias=False,
),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True),
]
layers += [
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)
] # output 1 channel prediction map
self.main = nn.Sequential(*layers)
if model_path is not None:
chkpt = torch.load(model_path, map_location="cpu")
if "params_d" in chkpt:
self.load_state_dict(
torch.load(model_path, map_location="cpu")["params_d"]
)
elif "params" in chkpt:
self.load_state_dict(
torch.load(model_path, map_location="cpu")["params"]
)
else:
raise ValueError(f"Wrong params!")
def forward(self, x):
return self.main(x)

View File

@ -221,7 +221,7 @@ class ControlNetData:
control_mode: str = Field(default="balanced") control_mode: str = Field(default="balanced")
@dataclass(frozen=True) @dataclass
class ConditioningData: class ConditioningData:
unconditioned_embeddings: torch.Tensor unconditioned_embeddings: torch.Tensor
text_embeddings: torch.Tensor text_embeddings: torch.Tensor
@ -422,7 +422,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise: torch.Tensor, noise: torch.Tensor,
callback: Callable[[PipelineIntermediateState], None] = None, callback: Callable[[PipelineIntermediateState], None] = None,
run_id=None, run_id=None,
**kwargs,
) -> InvokeAIStableDiffusionPipelineOutput: ) -> InvokeAIStableDiffusionPipelineOutput:
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
@ -443,7 +442,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise=noise, noise=noise,
run_id=run_id, run_id=run_id,
callback=callback, callback=callback,
**kwargs,
) )
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -469,7 +467,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
run_id=None, run_id=None,
callback: Callable[[PipelineIntermediateState], None] = None, callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None, control_data: List[ControlNetData] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]: ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if self.scheduler.config.get("cpu_only", False): if self.scheduler.config.get("cpu_only", False):
scheduler_device = torch.device('cpu') scheduler_device = torch.device('cpu')
@ -487,11 +484,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
timesteps, timesteps,
conditioning_data, conditioning_data,
noise=noise, noise=noise,
additional_guidance=additional_guidance,
run_id=run_id, run_id=run_id,
callback=callback, additional_guidance=additional_guidance,
control_data=control_data, control_data=control_data,
**kwargs,
callback=callback,
) )
return result.latents, result.attention_map_saver return result.latents, result.attention_map_saver
@ -505,7 +502,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
run_id: str = None, run_id: str = None,
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None, control_data: List[ControlNetData] = None,
**kwargs,
): ):
self._adjust_memory_efficient_attention(latents) self._adjust_memory_efficient_attention(latents)
if run_id is None: if run_id is None:
@ -546,7 +542,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count=len(timesteps), total_step_count=len(timesteps),
additional_guidance=additional_guidance, additional_guidance=additional_guidance,
control_data=control_data, control_data=control_data,
**kwargs,
) )
latents = step_output.prev_sample latents = step_output.prev_sample
@ -588,7 +583,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count: int, total_step_count: int,
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None, control_data: List[ControlNetData] = None,
**kwargs,
): ):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0] timestep = t[0]
@ -632,9 +626,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
encoder_hidden_states = conditioning_data.text_embeddings encoder_hidden_states = conditioning_data.text_embeddings
encoder_attention_mask = None
else: else:
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings, encoder_hidden_states, encoder_attention_mask = self.invokeai_diffuser._concat_conditionings_for_batch(
conditioning_data.text_embeddings]) conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings,
)
if isinstance(control_datum.weight, list): if isinstance(control_datum.weight, list):
# if controlnet has multiple weights, use the weight for the current step # if controlnet has multiple weights, use the weight for the current step
controlnet_weight = control_datum.weight[step_index] controlnet_weight = control_datum.weight[step_index]
@ -649,6 +646,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
controlnet_cond=control_datum.image_tensor, controlnet_cond=control_datum.image_tensor,
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
encoder_attention_mask=encoder_attention_mask,
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
return_dict=False, return_dict=False,
) )

View File

@ -237,11 +237,7 @@ class InvokeAIDiffuserComponent:
) )
return latents return latents
# methods below are called from do_diffusion_step and should be considered private to this class. def _concat_conditionings_for_batch(self, unconditioning, conditioning):
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
# fast batched path
def _pad_conditioning(cond, target_len, encoder_attention_mask): def _pad_conditioning(cond, target_len, encoder_attention_mask):
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype) conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
@ -266,16 +262,24 @@ class InvokeAIDiffuserComponent:
return cond, encoder_attention_mask return cond, encoder_attention_mask
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
encoder_attention_mask = None encoder_attention_mask = None
if unconditioning.shape[1] != conditioning.shape[1]: if unconditioning.shape[1] != conditioning.shape[1]:
max_len = max(unconditioning.shape[1], conditioning.shape[1]) max_len = max(unconditioning.shape[1], conditioning.shape[1])
unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask) unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask) conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
both_conditionings = torch.cat([unconditioning, conditioning]) return torch.cat([unconditioning, conditioning]), encoder_attention_mask
# methods below are called from do_diffusion_step and should be considered private to this class.
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
# fast batched path
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
unconditioning, conditioning
)
both_results = self.model_forward_callback( both_results = self.model_forward_callback(
x_twice, sigma_twice, both_conditionings, x_twice, sigma_twice, both_conditionings,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
@ -293,8 +297,32 @@ class InvokeAIDiffuserComponent:
**kwargs, **kwargs,
): ):
# low-memory sequential path # low-memory sequential path
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs) uncond_down_block, cond_down_block = None, None
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs) down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
if down_block_additional_residuals is not None:
uncond_down_block, cond_down_block = [], []
for down_block in down_block_additional_residuals:
_uncond_down, _cond_down = down_block.chunk(2)
uncond_down_block.append(_uncond_down)
cond_down_block.append(_cond_down)
uncond_mid_block, cond_mid_block = None, None
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
unconditioned_next_x = self.model_forward_callback(
x, sigma, unconditioning,
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
**kwargs,
)
conditioned_next_x = self.model_forward_callback(
x, sigma, conditioning,
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
**kwargs,
)
return unconditioned_next_x, conditioned_next_x return unconditioned_next_x, conditioned_next_x
# TODO: looks unused # TODO: looks unused
@ -328,6 +356,20 @@ class InvokeAIDiffuserComponent:
): ):
context: Context = self.cross_attention_control_context context: Context = self.cross_attention_control_context
uncond_down_block, cond_down_block = None, None
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
if down_block_additional_residuals is not None:
uncond_down_block, cond_down_block = [], []
for down_block in down_block_additional_residuals:
_uncond_down, _cond_down = down_block.chunk(2)
uncond_down_block.append(_uncond_down)
cond_down_block.append(_cond_down)
uncond_mid_block, cond_mid_block = None, None
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
cross_attn_processor_context = SwapCrossAttnContext( cross_attn_processor_context = SwapCrossAttnContext(
modified_text_embeddings=context.arguments.edited_conditioning, modified_text_embeddings=context.arguments.edited_conditioning,
index_map=context.cross_attention_index_map, index_map=context.cross_attention_index_map,
@ -340,6 +382,8 @@ class InvokeAIDiffuserComponent:
sigma, sigma,
unconditioning, unconditioning,
{"swap_cross_attn_context": cross_attn_processor_context}, {"swap_cross_attn_context": cross_attn_processor_context},
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
**kwargs, **kwargs,
) )
@ -352,6 +396,8 @@ class InvokeAIDiffuserComponent:
sigma, sigma,
conditioning, conditioning,
{"swap_cross_attn_context": cross_attn_processor_context}, {"swap_cross_attn_context": cross_attn_processor_context},
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
**kwargs, **kwargs,
) )
return unconditioned_next_x, conditioned_next_x return unconditioned_next_x, conditioned_next_x

View File

@ -0,0 +1,634 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import (
CrossAttnDownBlock2D,
DownBlock2D,
UNetMidBlock2DCrossAttn,
get_down_block,
)
from diffusers.models.unet_2d_condition import UNet2DConditionModel
import diffusers
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
# Modified ControlNetModel with encoder_attention_mask argument added
class ControlNetModel(ModelMixin, ConfigMixin):
"""
A ControlNet model.
Args:
in_channels (`int`, defaults to 4):
The number of channels in the input sample.
flip_sin_to_cos (`bool`, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, defaults to 0):
The frequency shift to apply to the time embedding.
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, defaults to 2):
The number of layers per block.
downsample_padding (`int`, defaults to 1):
The padding to use for the downsampling convolution.
mid_block_scale_factor (`float`, defaults to 1):
The scale factor to use for the mid block.
act_fn (`str`, defaults to "silu"):
The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
in post-processing.
norm_eps (`float`, defaults to 1e-5):
The epsilon to use for the normalization.
cross_attention_dim (`int`, defaults to 1280):
The dimension of the cross attention features.
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
The dimension of the attention heads.
use_linear_projection (`bool`, defaults to `False`):
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
num_class_embeds (`int`, *optional*, defaults to 0):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
upcast_attention (`bool`, defaults to `False`):
resnet_time_scale_shift (`str`, defaults to `"default"`):
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
`class_embed_type="projection"`.
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
The tuple of output channel for each block in the `conditioning_embedding` layer.
global_pool_conditions (`bool`, defaults to `False`):
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 4,
conditioning_channels: int = 3,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
attention_head_dim: Union[int, Tuple[int]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
projection_class_embeddings_input_dim: Optional[int] = None,
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
global_pool_conditions: bool = False,
):
super().__init__()
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
# Check inputs
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
# input
conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
)
# time
time_embed_dim = block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
act_fn=act_fn,
)
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
elif class_embed_type == "projection":
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
)
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
# 2. it projects from an arbitrary input dimension.
#
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
else:
self.class_embedding = None
# control net conditioning embedding
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
conditioning_embedding_channels=block_out_channels[0],
block_out_channels=conditioning_embedding_out_channels,
conditioning_channels=conditioning_channels,
)
self.down_blocks = nn.ModuleList([])
self.controlnet_down_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[i],
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
self.down_blocks.append(down_block)
for _ in range(layers_per_block):
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
if not is_final_block:
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
# mid
mid_block_channel = block_out_channels[-1]
controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_mid_block = controlnet_block
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=mid_block_channel,
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
@classmethod
def from_unet(
cls,
unet: UNet2DConditionModel,
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
load_weights_from_unet: bool = True,
):
r"""
Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
Parameters:
unet (`UNet2DConditionModel`):
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
where applicable.
"""
controlnet = cls(
in_channels=unet.config.in_channels,
flip_sin_to_cos=unet.config.flip_sin_to_cos,
freq_shift=unet.config.freq_shift,
down_block_types=unet.config.down_block_types,
only_cross_attention=unet.config.only_cross_attention,
block_out_channels=unet.config.block_out_channels,
layers_per_block=unet.config.layers_per_block,
downsample_padding=unet.config.downsample_padding,
mid_block_scale_factor=unet.config.mid_block_scale_factor,
act_fn=unet.config.act_fn,
norm_num_groups=unet.config.norm_num_groups,
norm_eps=unet.config.norm_eps,
cross_attention_dim=unet.config.cross_attention_dim,
attention_head_dim=unet.config.attention_head_dim,
num_attention_heads=unet.config.num_attention_heads,
use_linear_projection=unet.config.use_linear_projection,
class_embed_type=unet.config.class_embed_type,
num_class_embeds=unet.config.num_class_embeds,
upcast_attention=unet.config.upcast_attention,
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
)
if load_weights_from_unet:
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
if controlnet.class_embedding:
controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
return controlnet
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "set_processor"):
processors[f"{name}.processor"] = module.processor
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
self.set_attn_processor(AttnProcessor())
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
sliceable_head_dims.append(module.sliceable_head_dim)
for child in module.children():
fn_recursive_retrieve_sliceable_dims(child)
# retrieve number of attention layers
for module in self.children():
fn_recursive_retrieve_sliceable_dims(module)
num_sliceable_layers = len(sliceable_head_dims)
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
# make smallest slice possible
slice_size = num_sliceable_layers * [1]
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
)
for i in range(len(slice_size)):
size = slice_size[i]
dim = sliceable_head_dims[i]
if size is not None and size > dim:
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
# Recursively walk through all the children.
# Any children which exposes the set_attention_slice method
# gets the message
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
if hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size.pop())
for child in module.children():
fn_recursive_set_attention_slice(child, slice_size)
reversed_slice_size = list(reversed(slice_size))
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
controlnet_cond: torch.FloatTensor,
conditioning_scale: float = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple]:
"""
The [`ControlNetModel`] forward method.
Args:
sample (`torch.FloatTensor`):
The noisy input tensor.
timestep (`Union[torch.Tensor, float, int]`):
The number of timesteps to denoise an input.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states.
controlnet_cond (`torch.FloatTensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
encoder_attention_mask (`torch.Tensor`):
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
which adds large negative values to the attention scores corresponding to "discard" tokens.
guess_mode (`bool`, defaults to `False`):
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
Returns:
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
returned where the first element is the sample tensor.
"""
# check channel order
channel_order = self.config.controlnet_conditioning_channel_order
if channel_order == "rgb":
# in rgb order by default
...
elif channel_order == "bgr":
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
else:
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None:
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
# 2. pre-process
sample = self.conv_in(sample)
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
sample = sample + controlnet_cond
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
if self.mid_block is not None:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)
# 5. Control net blocks
controlnet_down_block_res_samples = ()
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
down_block_res_sample = controlnet_block(down_block_res_sample)
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
down_block_res_samples = controlnet_down_block_res_samples
mid_block_res_sample = self.controlnet_mid_block(sample)
# 6. scaling
if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
scales = scales * conditioning_scale
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
else:
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * conditioning_scale
if self.config.global_pool_conditions:
down_block_res_samples = [
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
]
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
if not return_dict:
return (down_block_res_samples, mid_block_res_sample)
return ControlNetOutput(
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
)
diffusers.ControlNetModel = ControlNetModel
diffusers.models.controlnet.ControlNetModel = ControlNetModel

View File

@ -64,22 +64,29 @@ sd-1/main/waifu-diffusion:
recommended: False recommended: False
sd-1/controlnet/canny: sd-1/controlnet/canny:
repo_id: lllyasviel/control_v11p_sd15_canny repo_id: lllyasviel/control_v11p_sd15_canny
recommended: True
sd-1/controlnet/inpaint: sd-1/controlnet/inpaint:
repo_id: lllyasviel/control_v11p_sd15_inpaint repo_id: lllyasviel/control_v11p_sd15_inpaint
sd-1/controlnet/mlsd: sd-1/controlnet/mlsd:
repo_id: lllyasviel/control_v11p_sd15_mlsd repo_id: lllyasviel/control_v11p_sd15_mlsd
sd-1/controlnet/depth: sd-1/controlnet/depth:
repo_id: lllyasviel/control_v11f1p_sd15_depth repo_id: lllyasviel/control_v11f1p_sd15_depth
recommended: True
sd-1/controlnet/normal_bae: sd-1/controlnet/normal_bae:
repo_id: lllyasviel/control_v11p_sd15_normalbae repo_id: lllyasviel/control_v11p_sd15_normalbae
sd-1/controlnet/seg: sd-1/controlnet/seg:
repo_id: lllyasviel/control_v11p_sd15_seg repo_id: lllyasviel/control_v11p_sd15_seg
sd-1/controlnet/lineart: sd-1/controlnet/lineart:
repo_id: lllyasviel/control_v11p_sd15_lineart repo_id: lllyasviel/control_v11p_sd15_lineart
recommended: True
sd-1/controlnet/lineart_anime: sd-1/controlnet/lineart_anime:
repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime
sd-1/controlnet/openpose:
repo_id: lllyasviel/control_v11p_sd15_openpose
recommended: True
sd-1/controlnet/scribble: sd-1/controlnet/scribble:
repo_id: lllyasviel/control_v11p_sd15_scribble repo_id: lllyasviel/control_v11p_sd15_scribble
recommended: False
sd-1/controlnet/softedge: sd-1/controlnet/softedge:
repo_id: lllyasviel/control_v11p_sd15_softedge repo_id: lllyasviel/control_v11p_sd15_softedge
sd-1/controlnet/shuffle: sd-1/controlnet/shuffle:
@ -90,9 +97,11 @@ sd-1/controlnet/ip2p:
repo_id: lllyasviel/control_v11e_sd15_ip2p repo_id: lllyasviel/control_v11e_sd15_ip2p
sd-1/embedding/EasyNegative: sd-1/embedding/EasyNegative:
path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
recommended: True
sd-1/embedding/ahx-beta-453407d: sd-1/embedding/ahx-beta-453407d:
repo_id: sd-concepts-library/ahx-beta-453407d repo_id: sd-concepts-library/ahx-beta-453407d
sd-1/lora/LowRA: sd-1/lora/LowRA:
path: https://civitai.com/api/download/models/63006 path: https://civitai.com/api/download/models/63006
recommended: True
sd-1/lora/Ink scenery: sd-1/lora/Ink scenery:
path: https://civitai.com/api/download/models/83390 path: https://civitai.com/api/download/models/83390

View File

@ -256,6 +256,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
widgets = dict() widgets = dict()
model_list = [x for x in self.all_models if self.all_models[x].model_type==model_type and not x in exclude] model_list = [x for x in self.all_models if self.all_models[x].model_type==model_type and not x in exclude]
model_labels = [self.model_labels[x] for x in model_list] model_labels = [self.model_labels[x] for x in model_list]
show_recommended = len(self.installed_models)==0
if len(model_list) > 0: if len(model_list) > 0:
max_width = max([len(x) for x in model_labels]) max_width = max([len(x) for x in model_labels])
columns = window_width // (max_width+8) # 8 characters for "[x] " and padding columns = window_width // (max_width+8) # 8 characters for "[x] " and padding
@ -280,7 +282,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
value=[ value=[
model_list.index(x) model_list.index(x)
for x in model_list for x in model_list
if self.all_models[x].installed if (show_recommended and self.all_models[x].recommended) \
or self.all_models[x].installed
], ],
max_height=len(model_list)//columns + 1, max_height=len(model_list)//columns + 1,
relx=4, relx=4,
@ -672,7 +675,9 @@ def select_and_download_models(opt: Namespace):
# pass # pass
installer = ModelInstall(config, prediction_type_helper=helper) installer = ModelInstall(config, prediction_type_helper=helper)
if opt.add or opt.delete: if opt.list_models:
installer.list_models(opt.list_models)
elif opt.add or opt.delete:
selections = InstallSelections( selections = InstallSelections(
install_models = opt.add or [], install_models = opt.add or [],
remove_models = opt.delete or [] remove_models = opt.delete or []
@ -696,7 +701,7 @@ def select_and_download_models(opt: Namespace):
# the third argument is needed in the Windows 11 environment in # the third argument is needed in the Windows 11 environment in
# order to launch and resize a console window running this program # order to launch and resize a console window running this program
set_min_terminal_size(MIN_COLS, MIN_LINES,'invokeai-model-install') set_min_terminal_size(MIN_COLS, MIN_LINES)
installApp = AddModelApplication(opt) installApp = AddModelApplication(opt)
try: try:
installApp.run() installApp.run()
@ -745,7 +750,7 @@ def main():
) )
parser.add_argument( parser.add_argument(
"--list-models", "--list-models",
choices=["diffusers","loras","controlnets","tis"], choices=[x.value for x in ModelType],
help="list installed models", help="list installed models",
) )
parser.add_argument( parser.add_argument(
@ -773,7 +778,7 @@ def main():
config.parse_args(invoke_args) config.parse_args(invoke_args)
logger = InvokeAILogger().getLogger(config=config) logger = InvokeAILogger().getLogger(config=config)
if not (config.conf_path / 'models.yaml').exists(): if not config.model_conf_path.exists():
logger.info( logger.info(
"Your InvokeAI root directory is not set up. Calling invokeai-configure." "Your InvokeAI root directory is not set up. Calling invokeai-configure."
) )

View File

@ -17,28 +17,20 @@ from shutil import get_terminal_size
from curses import BUTTON2_CLICKED,BUTTON3_CLICKED from curses import BUTTON2_CLICKED,BUTTON3_CLICKED
# minimum size for UIs # minimum size for UIs
MIN_COLS = 130 MIN_COLS = 136
MIN_LINES = 45 MIN_LINES = 45
# ------------------------------------- # -------------------------------------
def set_terminal_size(columns: int, lines: int, launch_command: str=None): def set_terminal_size(columns: int, lines: int):
ts = get_terminal_size() ts = get_terminal_size()
width = max(columns,ts.columns) width = max(columns,ts.columns)
height = max(lines,ts.lines) height = max(lines,ts.lines)
OS = platform.uname().system OS = platform.uname().system
if OS == "Windows": if OS == "Windows":
# The new Windows Terminal doesn't resize, so we relaunch in a CMD window. pass
# Would prefer to use execvpe() here, but somehow it is not working properly # not working reliably - ask user to adjust the window
# in the Windows 10 environment. #_set_terminal_size_powershell(width,height)
if 'IA_RELAUNCHED' not in os.environ:
args=['conhost']
args.extend([launch_command] if launch_command else [sys.argv[0]])
args.extend(sys.argv[1:])
os.environ['IA_RELAUNCHED'] = 'True'
os.execvp('conhost',args)
else:
_set_terminal_size_powershell(width,height)
elif OS in ["Darwin", "Linux"]: elif OS in ["Darwin", "Linux"]:
_set_terminal_size_unix(width,height) _set_terminal_size_unix(width,height)
@ -84,20 +76,14 @@ def _set_terminal_size_unix(width: int, height: int):
sys.stdout.write("\x1b[8;{height};{width}t".format(height=height, width=width)) sys.stdout.write("\x1b[8;{height};{width}t".format(height=height, width=width))
sys.stdout.flush() sys.stdout.flush()
def set_min_terminal_size(min_cols: int, min_lines: int, launch_command: str=None): def set_min_terminal_size(min_cols: int, min_lines: int):
# make sure there's enough room for the ui # make sure there's enough room for the ui
term_cols, term_lines = get_terminal_size() term_cols, term_lines = get_terminal_size()
if term_cols >= min_cols and term_lines >= min_lines: if term_cols >= min_cols and term_lines >= min_lines:
return return
cols = max(term_cols, min_cols) cols = max(term_cols, min_cols)
lines = max(term_lines, min_lines) lines = max(term_lines, min_lines)
set_terminal_size(cols, lines, launch_command) set_terminal_size(cols, lines)
# did it work?
term_cols, term_lines = get_terminal_size()
if term_cols < cols or term_lines < lines:
print(f'This window is too small for optimal display. For best results please enlarge it.')
input('After resizing, press any key to continue...')
class IntSlider(npyscreen.Slider): class IntSlider(npyscreen.Slider):
def translate_value(self): def translate_value(self):

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -12,7 +12,7 @@
margin: 0; margin: 0;
} }
</style> </style>
<script type="module" crossorigin src="./assets/index-8888b06f.js"></script> <script type="module" crossorigin src="./assets/index-f1a5f9cf.js"></script>
</head> </head>
<body dir="ltr"> <body dir="ltr">

View File

@ -399,6 +399,8 @@
"deleteModel": "Delete Model", "deleteModel": "Delete Model",
"deleteConfig": "Delete Config", "deleteConfig": "Delete Config",
"deleteMsg1": "Are you sure you want to delete this model from InvokeAI?", "deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
"modelDeleted": "Model Deleted",
"modelDeleteFailed": "Failed to delete model",
"deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.", "deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
"formMessageDiffusersModelLocation": "Diffusers Model Location", "formMessageDiffusersModelLocation": "Diffusers Model Location",
"formMessageDiffusersModelLocationDesc": "Please enter at least one.", "formMessageDiffusersModelLocationDesc": "Please enter at least one.",
@ -408,11 +410,13 @@
"convertToDiffusers": "Convert To Diffusers", "convertToDiffusers": "Convert To Diffusers",
"convertToDiffusersHelpText1": "This model will be converted to the 🧨 Diffusers format.", "convertToDiffusersHelpText1": "This model will be converted to the 🧨 Diffusers format.",
"convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.", "convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.",
"convertToDiffusersHelpText3": "Your checkpoint file on the disk will NOT be deleted or modified in anyway. You can add your checkpoint to the Model Manager again if you want to.", "convertToDiffusersHelpText3": "Your checkpoint file on disk WILL be deleted if it is in InvokeAI root folder. If it is in a custom location, then it WILL NOT be deleted.",
"convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.", "convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.",
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.", "convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.",
"convertToDiffusersHelpText6": "Do you wish to convert this model?", "convertToDiffusersHelpText6": "Do you wish to convert this model?",
"convertToDiffusersSaveLocation": "Save Location", "convertToDiffusersSaveLocation": "Save Location",
"noCustomLocationProvided": "No Custom Location Provided",
"convertingModelBegin": "Converting Model. Please wait.",
"v1": "v1", "v1": "v1",
"v2_base": "v2 (512px)", "v2_base": "v2 (512px)",
"v2_768": "v2 (768px)", "v2_768": "v2 (768px)",
@ -450,7 +454,8 @@
"none": "none", "none": "none",
"addDifference": "Add Difference", "addDifference": "Add Difference",
"pickModelType": "Pick Model Type", "pickModelType": "Pick Model Type",
"selectModel": "Select Model" "selectModel": "Select Model",
"importModels": "Import Models"
}, },
"parameters": { "parameters": {
"general": "General", "general": "General",
@ -572,6 +577,7 @@
"uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image", "uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image",
"downloadImageStarted": "Image Download Started", "downloadImageStarted": "Image Download Started",
"imageCopied": "Image Copied", "imageCopied": "Image Copied",
"problemCopyingImage": "Unable to Copy Image",
"imageLinkCopied": "Image Link Copied", "imageLinkCopied": "Image Link Copied",
"problemCopyingImageLink": "Unable to Copy Image Link", "problemCopyingImageLink": "Unable to Copy Image Link",
"imageNotLoaded": "No Image Loaded", "imageNotLoaded": "No Image Loaded",
@ -688,6 +694,15 @@
"reloadSchema": "Reload Schema", "reloadSchema": "Reload Schema",
"saveNodes": "Save Nodes", "saveNodes": "Save Nodes",
"loadNodes": "Load Nodes", "loadNodes": "Load Nodes",
"clearNodes": "Clear Nodes" "clearNodes": "Clear Nodes",
"zoomInNodes": "Zoom In",
"zoomOutNodes": "Zoom Out",
"fitViewportNodes": "Fit View",
"hideGraphNodes": "Hide Graph Overlay",
"showGraphNodes": "Show Graph Overlay",
"hideLegendNodes": "Hide Field Type Legend",
"showLegendNodes": "Show Field Type Legend",
"hideMinimapnodes": "Hide MiniMap",
"showMinimapnodes": "Show MiniMap"
} }
} }

View File

@ -1,7 +1,9 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice'; import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { import {
setActiveTab, setActiveTab,
toggleGalleryPanel, toggleGalleryPanel,
@ -14,10 +16,11 @@ import React, { memo } from 'react';
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook'; import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
const globalHotkeysSelector = createSelector( const globalHotkeysSelector = createSelector(
(state: RootState) => state.hotkeys, [(state: RootState) => state.hotkeys, (state: RootState) => state.ui],
(hotkeys) => { (hotkeys, ui) => {
const { shift } = hotkeys; const { shift } = hotkeys;
return { shift }; const { shouldPinParametersPanel, shouldPinGallery } = ui;
return { shift, shouldPinGallery, shouldPinParametersPanel };
}, },
{ {
memoizeOptions: { memoizeOptions: {
@ -34,7 +37,10 @@ const globalHotkeysSelector = createSelector(
*/ */
const GlobalHotkeys: React.FC = () => { const GlobalHotkeys: React.FC = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { shift } = useAppSelector(globalHotkeysSelector); const { shift, shouldPinParametersPanel, shouldPinGallery } = useAppSelector(
globalHotkeysSelector
);
const activeTabName = useAppSelector(activeTabNameSelector);
useHotkeys( useHotkeys(
'*', '*',
@ -51,18 +57,30 @@ const GlobalHotkeys: React.FC = () => {
useHotkeys('o', () => { useHotkeys('o', () => {
dispatch(toggleParametersPanel()); dispatch(toggleParametersPanel());
if (activeTabName === 'unifiedCanvas' && shouldPinParametersPanel) {
dispatch(requestCanvasRescale());
}
}); });
useHotkeys(['shift+o'], () => { useHotkeys(['shift+o'], () => {
dispatch(togglePinParametersPanel()); dispatch(togglePinParametersPanel());
if (activeTabName === 'unifiedCanvas') {
dispatch(requestCanvasRescale());
}
}); });
useHotkeys('g', () => { useHotkeys('g', () => {
dispatch(toggleGalleryPanel()); dispatch(toggleGalleryPanel());
if (activeTabName === 'unifiedCanvas' && shouldPinGallery) {
dispatch(requestCanvasRescale());
}
}); });
useHotkeys(['shift+g'], () => { useHotkeys(['shift+g'], () => {
dispatch(togglePinGalleryPanel()); dispatch(togglePinGalleryPanel());
if (activeTabName === 'unifiedCanvas') {
dispatch(requestCanvasRescale());
}
}); });
useHotkeys('1', () => { useHotkeys('1', () => {

View File

@ -88,6 +88,8 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage'; import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes'; import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage'; import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addModelLoadStartedEventListener } from './listeners/socketio/socketModelLoadStarted';
import { addModelLoadCompletedEventListener } from './listeners/socketio/socketModelLoadCompleted';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
@ -177,6 +179,8 @@ addSocketConnectedListener();
addSocketDisconnectedListener(); addSocketDisconnectedListener();
addSocketSubscribedListener(); addSocketSubscribedListener();
addSocketUnsubscribedListener(); addSocketUnsubscribedListener();
addModelLoadStartedEventListener();
addModelLoadCompletedEventListener();
// Session Created // Session Created
addSessionCreatedPendingListener(); addSessionCreatedPendingListener();

View File

@ -0,0 +1,28 @@
import { log } from 'app/logging/useLogger';
import {
appSocketModelLoadCompleted,
socketModelLoadCompleted,
} from 'services/events/actions';
import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' });
export const addModelLoadCompletedEventListener = () => {
startAppListening({
actionCreator: socketModelLoadCompleted,
effect: (action, { dispatch, getState }) => {
const { model_name, model_type, submodel } = action.payload.data;
let modelString = `${model_type} model: ${model_name}`;
if (submodel) {
modelString = modelString.concat(`, submodel: ${submodel}`);
}
moduleLog.debug(action.payload, `Model load completed (${modelString})`);
// pass along the socket event as an application action
dispatch(appSocketModelLoadCompleted(action.payload));
},
});
};

View File

@ -0,0 +1,28 @@
import { log } from 'app/logging/useLogger';
import {
appSocketModelLoadStarted,
socketModelLoadStarted,
} from 'services/events/actions';
import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' });
export const addModelLoadStartedEventListener = () => {
startAppListening({
actionCreator: socketModelLoadStarted,
effect: (action, { dispatch, getState }) => {
const { model_name, model_type, submodel } = action.payload.data;
let modelString = `${model_type} model: ${model_name}`;
if (submodel) {
modelString = modelString.concat(`, submodel: ${submodel}`);
}
moduleLog.debug(action.payload, `Model load started (${modelString})`);
// pass along the socket event as an application action
dispatch(appSocketModelLoadStarted(action.payload));
},
});
};

View File

@ -21,6 +21,7 @@ import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import configReducer from 'features/system/store/configSlice'; import configReducer from 'features/system/store/configSlice';
import systemReducer from 'features/system/store/systemSlice'; import systemReducer from 'features/system/store/systemSlice';
import modelmanagerReducer from 'features/ui/components/tabs/ModelManager/store/modelManagerSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice'; import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import uiReducer from 'features/ui/store/uiSlice'; import uiReducer from 'features/ui/store/uiSlice';
@ -49,6 +50,7 @@ const allReducers = {
dynamicPrompts: dynamicPromptsReducer, dynamicPrompts: dynamicPromptsReducer,
imageDeletion: imageDeletionReducer, imageDeletion: imageDeletionReducer,
lora: loraReducer, lora: loraReducer,
modelmanager: modelmanagerReducer,
[api.reducerPath]: api.reducer, [api.reducerPath]: api.reducer,
}; };
@ -67,6 +69,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'controlNet', 'controlNet',
'dynamicPrompts', 'dynamicPrompts',
'lora', 'lora',
'modelmanager',
]; ];
export const store = configureStore({ export const store = configureStore({

View File

@ -21,6 +21,7 @@ import { ImageDTO } from 'services/api/types';
import { mode } from 'theme/util/mode'; import { mode } from 'theme/util/mode';
import IAIDraggable from './IAIDraggable'; import IAIDraggable from './IAIDraggable';
import IAIDroppable from './IAIDroppable'; import IAIDroppable from './IAIDroppable';
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
type IAIDndImageProps = { type IAIDndImageProps = {
imageDTO: ImageDTO | undefined; imageDTO: ImageDTO | undefined;
@ -96,119 +97,124 @@ const IAIDndImage = (props: IAIDndImageProps) => {
}; };
return ( return (
<Flex <ImageContextMenu imageDTO={imageDTO}>
sx={{ {(ref) => (
width: 'full',
height: 'full',
alignItems: 'center',
justifyContent: 'center',
position: 'relative',
minW: minSize ? minSize : undefined,
minH: minSize ? minSize : undefined,
userSelect: 'none',
cursor: isDragDisabled || !imageDTO ? 'default' : 'pointer',
}}
>
{imageDTO && (
<Flex <Flex
ref={ref}
sx={{ sx={{
w: 'full', width: 'full',
h: 'full', height: 'full',
position: fitContainer ? 'absolute' : 'relative',
alignItems: 'center', alignItems: 'center',
justifyContent: 'center', justifyContent: 'center',
position: 'relative',
minW: minSize ? minSize : undefined,
minH: minSize ? minSize : undefined,
userSelect: 'none',
cursor: isDragDisabled || !imageDTO ? 'default' : 'pointer',
}} }}
> >
<Image {imageDTO && (
src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url} <Flex
fallbackStrategy="beforeLoadOrError"
// If we fall back to thumbnail, it feels much snappier than the skeleton...
fallbackSrc={imageDTO.thumbnail_url}
// fallback={<IAILoadingImageFallback image={imageDTO} />}
width={imageDTO.width}
height={imageDTO.height}
onError={onError}
draggable={false}
sx={{
objectFit: 'contain',
maxW: 'full',
maxH: 'full',
borderRadius: 'base',
shadow: isSelected ? 'selected.light' : undefined,
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
...imageSx,
}}
/>
{withMetadataOverlay && <ImageMetadataOverlay image={imageDTO} />}
</Flex>
)}
{!imageDTO && !isUploadDisabled && (
<>
<Flex
sx={{
minH: minSize,
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
transitionProperty: 'common',
transitionDuration: '0.1s',
color: mode('base.500', 'base.500')(colorMode),
...uploadButtonStyles,
}}
{...getUploadButtonProps()}
>
<input {...getUploadInputProps()} />
<Icon
as={FaUpload}
sx={{ sx={{
boxSize: 16, w: 'full',
h: 'full',
position: fitContainer ? 'absolute' : 'relative',
alignItems: 'center',
justifyContent: 'center',
}}
>
<Image
src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url}
fallbackStrategy="beforeLoadOrError"
// If we fall back to thumbnail, it feels much snappier than the skeleton...
fallbackSrc={imageDTO.thumbnail_url}
// fallback={<IAILoadingImageFallback image={imageDTO} />}
width={imageDTO.width}
height={imageDTO.height}
onError={onError}
draggable={false}
sx={{
objectFit: 'contain',
maxW: 'full',
maxH: 'full',
borderRadius: 'base',
shadow: isSelected ? 'selected.light' : undefined,
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
...imageSx,
}}
/>
{withMetadataOverlay && <ImageMetadataOverlay image={imageDTO} />}
</Flex>
)}
{!imageDTO && !isUploadDisabled && (
<>
<Flex
sx={{
minH: minSize,
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
transitionProperty: 'common',
transitionDuration: '0.1s',
color: mode('base.500', 'base.500')(colorMode),
...uploadButtonStyles,
}}
{...getUploadButtonProps()}
>
<input {...getUploadInputProps()} />
<Icon
as={FaUpload}
sx={{
boxSize: 16,
}}
/>
</Flex>
</>
)}
{!imageDTO && isUploadDisabled && noContentFallback}
{!isDropDisabled && (
<IAIDroppable
data={droppableData}
disabled={isDropDisabled}
dropLabel={dropLabel}
/>
)}
{imageDTO && !isDragDisabled && (
<IAIDraggable
data={draggableData}
disabled={isDragDisabled || !imageDTO}
onClick={onClick}
/>
)}
{onClickReset && withResetIcon && imageDTO && (
<IAIIconButton
onClick={onClickReset}
aria-label={resetTooltip}
tooltip={resetTooltip}
icon={resetIcon}
size="sm"
variant="link"
sx={{
position: 'absolute',
top: 1,
insetInlineEnd: 1,
p: 0,
minW: 0,
svg: {
transitionProperty: 'common',
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: resetIconShadow,
},
}} }}
/> />
</Flex> )}
</> </Flex>
)} )}
{!imageDTO && isUploadDisabled && noContentFallback} </ImageContextMenu>
{!isDropDisabled && (
<IAIDroppable
data={droppableData}
disabled={isDropDisabled}
dropLabel={dropLabel}
/>
)}
{imageDTO && !isDragDisabled && (
<IAIDraggable
data={draggableData}
disabled={isDragDisabled || !imageDTO}
onClick={onClick}
/>
)}
{onClickReset && withResetIcon && imageDTO && (
<IAIIconButton
onClick={onClickReset}
aria-label={resetTooltip}
tooltip={resetTooltip}
icon={resetIcon}
size="sm"
variant="link"
sx={{
position: 'absolute',
top: 1,
insetInlineEnd: 1,
p: 0,
minW: 0,
svg: {
transitionProperty: 'common',
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: resetIconShadow,
},
}}
/>
)}
</Flex>
); );
}; };

View File

@ -8,19 +8,34 @@ import {
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { stopPastePropagation } from 'common/util/stopPastePropagation'; import { stopPastePropagation } from 'common/util/stopPastePropagation';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice'; import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { ChangeEvent, KeyboardEvent, memo, useCallback } from 'react'; import {
CSSProperties,
ChangeEvent,
KeyboardEvent,
memo,
useCallback,
} from 'react';
interface IAIInputProps extends InputProps { interface IAIInputProps extends InputProps {
label?: string; label?: string;
labelPos?: 'top' | 'side';
value?: string; value?: string;
size?: string; size?: string;
onChange?: (e: ChangeEvent<HTMLInputElement>) => void; onChange?: (e: ChangeEvent<HTMLInputElement>) => void;
formControlProps?: Omit<FormControlProps, 'isInvalid' | 'isDisabled'>; formControlProps?: Omit<FormControlProps, 'isInvalid' | 'isDisabled'>;
} }
const labelPosVerticalStyle: CSSProperties = {
display: 'flex',
flexDirection: 'row',
alignItems: 'center',
gap: 10,
};
const IAIInput = (props: IAIInputProps) => { const IAIInput = (props: IAIInputProps) => {
const { const {
label = '', label = '',
labelPos = 'top',
isDisabled = false, isDisabled = false,
isInvalid, isInvalid,
formControlProps, formControlProps,
@ -51,6 +66,7 @@ const IAIInput = (props: IAIInputProps) => {
isInvalid={isInvalid} isInvalid={isInvalid}
isDisabled={isDisabled} isDisabled={isDisabled}
{...formControlProps} {...formControlProps}
style={labelPos === 'side' ? labelPosVerticalStyle : undefined}
> >
{label !== '' && <FormLabel>{label}</FormLabel>} {label !== '' && <FormLabel>{label}</FormLabel>}
<Input <Input

View File

@ -36,6 +36,7 @@ export default function IAIMantineTextInput(props: IAIMantineTextInputProps) {
label: { label: {
color: mode(base700, base300)(colorMode), color: mode(base700, base300)(colorMode),
fontWeight: 'normal', fontWeight: 'normal',
marginBottom: 4,
}, },
})} })}
{...rest} {...rest}

View File

@ -9,14 +9,14 @@ export type IAISelectDataType = {
tooltip?: string; tooltip?: string;
}; };
type IAISelectProps = Omit<SelectProps, 'label'> & { export type IAISelectProps = Omit<SelectProps, 'label'> & {
tooltip?: string; tooltip?: string;
inputRef?: RefObject<HTMLInputElement>; inputRef?: RefObject<HTMLInputElement>;
label?: string; label?: string;
}; };
const IAIMantineSelect = (props: IAISelectProps) => { const IAIMantineSelect = (props: IAISelectProps) => {
const { tooltip, inputRef, label, disabled, ...rest } = props; const { tooltip, inputRef, label, disabled, required, ...rest } = props;
const styles = useMantineSelectStyles(); const styles = useMantineSelectStyles();
@ -25,7 +25,7 @@ const IAIMantineSelect = (props: IAISelectProps) => {
<Select <Select
label={ label={
label ? ( label ? (
<FormControl isDisabled={disabled}> <FormControl isRequired={required} isDisabled={disabled}>
<FormLabel>{label}</FormLabel> <FormLabel>{label}</FormLabel>
</FormControl> </FormControl>
) : undefined ) : undefined

View File

@ -3,4 +3,5 @@ import dateFormat from 'dateformat';
/** /**
* Get a `now` timestamp with 1s precision, formatted as ISO datetime. * Get a `now` timestamp with 1s precision, formatted as ISO datetime.
*/ */
export const getTimestamp = () => dateFormat(new Date(), 'isoDateTime'); export const getTimestamp = () =>
dateFormat(new Date(), `yyyy-mm-dd'T'HH:MM:ss:lo`);

View File

@ -11,6 +11,7 @@ import {
setIsMouseOverBoundingBox, setIsMouseOverBoundingBox,
setIsMovingBoundingBox, setIsMovingBoundingBox,
setIsTransformingBoundingBox, setIsTransformingBoundingBox,
setShouldSnapToGrid,
} from 'features/canvas/store/canvasSlice'; } from 'features/canvas/store/canvasSlice';
import { uiSelector } from 'features/ui/store/uiSelectors'; import { uiSelector } from 'features/ui/store/uiSelectors';
import Konva from 'konva'; import Konva from 'konva';
@ -20,6 +21,7 @@ import { Vector2d } from 'konva/lib/types';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { useCallback, useEffect, useRef, useState } from 'react'; import { useCallback, useEffect, useRef, useState } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { Group, Rect, Transformer } from 'react-konva'; import { Group, Rect, Transformer } from 'react-konva';
const boundingBoxPreviewSelector = createSelector( const boundingBoxPreviewSelector = createSelector(
@ -91,6 +93,10 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => {
const scaledStep = 64 * stageScale; const scaledStep = 64 * stageScale;
useHotkeys('N', () => {
dispatch(setShouldSnapToGrid(!shouldSnapToGrid));
});
const handleOnDragMove = useCallback( const handleOnDragMove = useCallback(
(e: KonvaEventObject<DragEvent>) => { (e: KonvaEventObject<DragEvent>) => {
if (!shouldSnapToGrid) { if (!shouldSnapToGrid) {

View File

@ -139,7 +139,7 @@ const IAICanvasToolChooserOptions = () => {
); );
useHotkeys( useHotkeys(
['shift+BracketLeft'], ['Shift+BracketLeft'],
() => { () => {
dispatch( dispatch(
setBrushColor({ setBrushColor({
@ -156,7 +156,7 @@ const IAICanvasToolChooserOptions = () => {
); );
useHotkeys( useHotkeys(
['shift+BracketRight'], ['Shift+BracketRight'],
() => { () => {
dispatch( dispatch(
setBrushColor({ setBrushColor({

View File

@ -48,6 +48,7 @@ import IAICanvasRedoButton from './IAICanvasRedoButton';
import IAICanvasSettingsButtonPopover from './IAICanvasSettingsButtonPopover'; import IAICanvasSettingsButtonPopover from './IAICanvasSettingsButtonPopover';
import IAICanvasToolChooserOptions from './IAICanvasToolChooserOptions'; import IAICanvasToolChooserOptions from './IAICanvasToolChooserOptions';
import IAICanvasUndoButton from './IAICanvasUndoButton'; import IAICanvasUndoButton from './IAICanvasUndoButton';
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
export const selector = createSelector( export const selector = createSelector(
[systemSelector, canvasSelector, isStagingSelector], [systemSelector, canvasSelector, isStagingSelector],
@ -79,6 +80,7 @@ const IAICanvasToolbar = () => {
const canvasBaseLayer = getCanvasBaseLayer(); const canvasBaseLayer = getCanvasBaseLayer();
const { t } = useTranslation(); const { t } = useTranslation();
const { isClipboardAPIAvailable } = useCopyImageToClipboard();
const { openUploader } = useImageUploader(); const { openUploader } = useImageUploader();
@ -136,10 +138,10 @@ const IAICanvasToolbar = () => {
handleCopyImageToClipboard(); handleCopyImageToClipboard();
}, },
{ {
enabled: () => !isStaging, enabled: () => !isStaging && isClipboardAPIAvailable,
preventDefault: true, preventDefault: true,
}, },
[canvasBaseLayer, isProcessing] [canvasBaseLayer, isProcessing, isClipboardAPIAvailable]
); );
useHotkeys( useHotkeys(
@ -189,6 +191,9 @@ const IAICanvasToolbar = () => {
}; };
const handleCopyImageToClipboard = () => { const handleCopyImageToClipboard = () => {
if (!isClipboardAPIAvailable) {
return;
}
dispatch(canvasCopiedToClipboard()); dispatch(canvasCopiedToClipboard());
}; };
@ -256,13 +261,15 @@ const IAICanvasToolbar = () => {
onClick={handleSaveToGallery} onClick={handleSaveToGallery}
isDisabled={isStaging} isDisabled={isStaging}
/> />
<IAIIconButton {isClipboardAPIAvailable && (
aria-label={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`} <IAIIconButton
tooltip={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`} aria-label={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`}
icon={<FaCopy />} tooltip={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`}
onClick={handleCopyImageToClipboard} icon={<FaCopy />}
isDisabled={isStaging} onClick={handleCopyImageToClipboard}
/> isDisabled={isStaging}
/>
)}
<IAIIconButton <IAIIconButton
aria-label={`${t('unifiedCanvas.downloadAsImage')} (Shift+D)`} aria-label={`${t('unifiedCanvas.downloadAsImage')} (Shift+D)`}
tooltip={`${t('unifiedCanvas.downloadAsImage')} (Shift+D)`} tooltip={`${t('unifiedCanvas.downloadAsImage')} (Shift+D)`}

View File

@ -1,7 +1,16 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { ButtonGroup, Flex, FlexProps, Link } from '@chakra-ui/react'; import {
ButtonGroup,
Flex,
FlexProps,
Link,
Menu,
MenuButton,
MenuItem,
MenuList,
} from '@chakra-ui/react';
// import { runESRGAN, runFacetool } from 'app/socketio/actions'; // import { runESRGAN, runFacetool } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
@ -20,6 +29,7 @@ import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/U
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions'; import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { import {
setActiveTab, setActiveTab,
@ -48,6 +58,8 @@ import {
} from 'services/api/endpoints/images'; } from 'services/api/endpoints/images';
import { useDebounce } from 'use-debounce'; import { useDebounce } from 'use-debounce';
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions'; import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
import { menuListMotionProps } from 'theme/components/menu';
import SingleSelectionMenuItems from '../ImageContextMenu/SingleSelectionMenuItems';
const currentImageButtonsSelector = createSelector( const currentImageButtonsSelector = createSelector(
[stateSelector, activeTabNameSelector], [stateSelector, activeTabNameSelector],
@ -120,6 +132,9 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const toaster = useAppToaster(); const toaster = useAppToaster();
const { t } = useTranslation(); const { t } = useTranslation();
const { isClipboardAPIAvailable, copyImageToClipboard } =
useCopyImageToClipboard();
const { recallBothPrompts, recallSeed, recallAllParameters } = const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters(); useRecallParameters();
@ -128,7 +143,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
500 500
); );
const { currentData: image, isFetching } = useGetImageDTOQuery( const { currentData: imageDTO, isFetching } = useGetImageDTOQuery(
lastSelectedImage ?? skipToken lastSelectedImage ?? skipToken
); );
@ -142,15 +157,15 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const handleCopyImageLink = useCallback(() => { const handleCopyImageLink = useCallback(() => {
const getImageUrl = () => { const getImageUrl = () => {
if (!image) { if (!imageDTO) {
return; return;
} }
if (image.image_url.startsWith('http')) { if (imageDTO.image_url.startsWith('http')) {
return image.image_url; return imageDTO.image_url;
} }
return window.location.toString() + image.image_url; return window.location.toString() + imageDTO.image_url;
}; };
const url = getImageUrl(); const url = getImageUrl();
@ -174,7 +189,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
isClosable: true, isClosable: true,
}); });
}); });
}, [toaster, t, image]); }, [toaster, t, imageDTO]);
const handleClickUseAllParameters = useCallback(() => { const handleClickUseAllParameters = useCallback(() => {
recallAllParameters(metadata); recallAllParameters(metadata);
@ -192,31 +207,31 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
recallSeed(metadata?.seed); recallSeed(metadata?.seed);
}, [metadata?.seed, recallSeed]); }, [metadata?.seed, recallSeed]);
useHotkeys('s', handleUseSeed, [image]); useHotkeys('s', handleUseSeed, [imageDTO]);
const handleUsePrompt = useCallback(() => { const handleUsePrompt = useCallback(() => {
recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt); recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt);
}, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]); }, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]);
useHotkeys('p', handleUsePrompt, [image]); useHotkeys('p', handleUsePrompt, [imageDTO]);
const handleSendToImageToImage = useCallback(() => { const handleSendToImageToImage = useCallback(() => {
dispatch(sentImageToImg2Img()); dispatch(sentImageToImg2Img());
dispatch(initialImageSelected(image)); dispatch(initialImageSelected(imageDTO));
}, [dispatch, image]); }, [dispatch, imageDTO]);
useHotkeys('shift+i', handleSendToImageToImage, [image]); useHotkeys('shift+i', handleSendToImageToImage, [imageDTO]);
const handleClickUpscale = useCallback(() => { const handleClickUpscale = useCallback(() => {
// selectedImage && dispatch(runESRGAN(selectedImage)); // selectedImage && dispatch(runESRGAN(selectedImage));
}, []); }, []);
const handleDelete = useCallback(() => { const handleDelete = useCallback(() => {
if (!image) { if (!imageDTO) {
return; return;
} }
dispatch(imageToDeleteSelected(image)); dispatch(imageToDeleteSelected(imageDTO));
}, [dispatch, image]); }, [dispatch, imageDTO]);
useHotkeys( useHotkeys(
'Shift+U', 'Shift+U',
@ -236,7 +251,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}, },
[ [
isUpscalingEnabled, isUpscalingEnabled,
image, imageDTO,
isESRGANAvailable, isESRGANAvailable,
shouldDisableToolbarButtons, shouldDisableToolbarButtons,
isConnected, isConnected,
@ -268,7 +283,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
[ [
isFaceRestoreEnabled, isFaceRestoreEnabled,
image, imageDTO,
isGFPGANAvailable, isGFPGANAvailable,
shouldDisableToolbarButtons, shouldDisableToolbarButtons,
isConnected, isConnected,
@ -283,10 +298,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
); );
const handleSendToCanvas = useCallback(() => { const handleSendToCanvas = useCallback(() => {
if (!image) return; if (!imageDTO) return;
dispatch(sentImageToCanvas()); dispatch(sentImageToCanvas());
dispatch(setInitialCanvasImage(image)); dispatch(setInitialCanvasImage(imageDTO));
dispatch(requestCanvasRescale()); dispatch(requestCanvasRescale());
if (activeTabName !== 'unifiedCanvas') { if (activeTabName !== 'unifiedCanvas') {
@ -299,12 +314,12 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
duration: 2500, duration: 2500,
isClosable: true, isClosable: true,
}); });
}, [image, dispatch, activeTabName, toaster, t]); }, [imageDTO, dispatch, activeTabName, toaster, t]);
useHotkeys( useHotkeys(
'i', 'i',
() => { () => {
if (image) { if (imageDTO) {
handleClickShowImageDetails(); handleClickShowImageDetails();
} else { } else {
toaster({ toaster({
@ -315,13 +330,20 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}); });
} }
}, },
[image, shouldShowImageDetails, toaster] [imageDTO, shouldShowImageDetails, toaster]
); );
const handleClickProgressImagesToggle = useCallback(() => { const handleClickProgressImagesToggle = useCallback(() => {
dispatch(setShouldShowProgressInViewer(!shouldShowProgressInViewer)); dispatch(setShouldShowProgressInViewer(!shouldShowProgressInViewer));
}, [dispatch, shouldShowProgressInViewer]); }, [dispatch, shouldShowProgressInViewer]);
const handleCopyImage = useCallback(() => {
if (!imageDTO) {
return;
}
copyImageToClipboard(imageDTO.image_url);
}, [copyImageToClipboard, imageDTO]);
return ( return (
<> <>
<Flex <Flex
@ -334,63 +356,18 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
{...props} {...props}
> >
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}> <ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
<IAIPopover <Menu>
triggerComponent={ <MenuButton
<IAIIconButton as={IAIIconButton}
aria-label={`${t('parameters.sendTo')}...`} aria-label={`${t('parameters.sendTo')}...`}
tooltip={`${t('parameters.sendTo')}...`} tooltip={`${t('parameters.sendTo')}...`}
isDisabled={!image} isDisabled={!imageDTO}
icon={<FaShareAlt />} icon={<FaShareAlt />}
/> />
} <MenuList motionProps={menuListMotionProps}>
> {imageDTO && <SingleSelectionMenuItems imageDTO={imageDTO} />}
<Flex </MenuList>
sx={{ </Menu>
flexDirection: 'column',
rowGap: 2,
}}
>
<IAIButton
size="sm"
onClick={handleSendToImageToImage}
leftIcon={<FaShare />}
id="send-to-img2img"
>
{t('parameters.sendToImg2Img')}
</IAIButton>
{isCanvasEnabled && (
<IAIButton
size="sm"
onClick={handleSendToCanvas}
leftIcon={<FaShare />}
id="send-to-canvas"
>
{t('parameters.sendToUnifiedCanvas')}
</IAIButton>
)}
{/* <IAIButton
size="sm"
onClick={handleCopyImage}
leftIcon={<FaCopy />}
>
{t('parameters.copyImage')}
</IAIButton> */}
<IAIButton
size="sm"
onClick={handleCopyImageLink}
leftIcon={<FaCopy />}
>
{t('parameters.copyImageToLink')}
</IAIButton>
<Link download={true} href={image?.image_url} target="_blank">
<IAIButton leftIcon={<FaDownload />} size="sm" w="100%">
{t('parameters.downloadImage')}
</IAIButton>
</Link>
</Flex>
</IAIPopover>
</ButtonGroup> </ButtonGroup>
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}> <ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
@ -443,7 +420,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<IAIButton <IAIButton
isDisabled={ isDisabled={
!isGFPGANAvailable || !isGFPGANAvailable ||
!image || !imageDTO ||
!(isConnected && !isProcessing) || !(isConnected && !isProcessing) ||
!facetoolStrength !facetoolStrength
} }
@ -474,7 +451,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<IAIButton <IAIButton
isDisabled={ isDisabled={
!isESRGANAvailable || !isESRGANAvailable ||
!image || !imageDTO ||
!(isConnected && !isProcessing) || !(isConnected && !isProcessing) ||
!upscalingLevel !upscalingLevel
} }

View File

@ -4,13 +4,14 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu'; import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu';
import { memo, useMemo } from 'react'; import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
import { menuListMotionProps } from 'theme/components/menu';
import MultipleSelectionMenuItems from './MultipleSelectionMenuItems'; import MultipleSelectionMenuItems from './MultipleSelectionMenuItems';
import SingleSelectionMenuItems from './SingleSelectionMenuItems'; import SingleSelectionMenuItems from './SingleSelectionMenuItems';
type Props = { type Props = {
imageDTO: ImageDTO; imageDTO: ImageDTO | undefined;
children: ContextMenuProps<HTMLDivElement>['children']; children: ContextMenuProps<HTMLDivElement>['children'];
}; };
@ -31,18 +32,32 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => {
const { selectionCount } = useAppSelector(selector); const { selectionCount } = useAppSelector(selector);
const handleContextMenu = useCallback((e: MouseEvent<HTMLDivElement>) => {
e.preventDefault();
}, []);
return ( return (
<ContextMenu<HTMLDivElement> <ContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }} menuProps={{ size: 'sm', isLazy: true }}
renderMenu={() => ( menuButtonProps={{
<MenuList sx={{ visibility: 'visible !important' }}> bg: 'transparent',
{selectionCount === 1 ? ( _hover: { bg: 'transparent' },
<SingleSelectionMenuItems imageDTO={imageDTO} /> }}
) : ( renderMenu={() =>
<MultipleSelectionMenuItems /> imageDTO ? (
)} <MenuList
</MenuList> sx={{ visibility: 'visible !important' }}
)} motionProps={menuListMotionProps}
onContextMenu={handleContextMenu}
>
{selectionCount === 1 ? (
<SingleSelectionMenuItems imageDTO={imageDTO} />
) : (
<MultipleSelectionMenuItems />
)}
</MenuList>
) : null
}
> >
{children} {children}
</ContextMenu> </ContextMenu>

View File

@ -1,5 +1,4 @@
import { ExternalLinkIcon } from '@chakra-ui/icons'; import { Link, MenuItem } from '@chakra-ui/react';
import { MenuItem } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
@ -14,11 +13,21 @@ import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletio
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions'; import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
import { setActiveTab } from 'features/ui/store/uiSlice'; import { setActiveTab } from 'features/ui/store/uiSlice';
import { memo, useCallback, useContext, useMemo } from 'react'; import { memo, useCallback, useContext, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { FaFolder, FaShare, FaTrash } from 'react-icons/fa'; import {
import { IoArrowUndoCircleOutline } from 'react-icons/io5'; FaAsterisk,
FaCopy,
FaDownload,
FaExternalLinkAlt,
FaFolder,
FaQuoteRight,
FaSeedling,
FaShare,
FaTrash,
} from 'react-icons/fa';
import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages'; import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages';
import { useGetImageMetadataQuery } from 'services/api/endpoints/images'; import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
@ -61,6 +70,9 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const { currentData } = useGetImageMetadataQuery(imageDTO.image_name); const { currentData } = useGetImageMetadataQuery(imageDTO.image_name);
const { isClipboardAPIAvailable, copyImageToClipboard } =
useCopyImageToClipboard();
const metadata = currentData?.metadata; const metadata = currentData?.metadata;
const handleDelete = useCallback(() => { const handleDelete = useCallback(() => {
@ -130,13 +142,27 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
dispatch(imagesAddedToBatch([imageDTO.image_name])); dispatch(imagesAddedToBatch([imageDTO.image_name]));
}, [dispatch, imageDTO.image_name]); }, [dispatch, imageDTO.image_name]);
const handleCopyImage = useCallback(() => {
copyImageToClipboard(imageDTO.image_url);
}, [copyImageToClipboard, imageDTO.image_url]);
return ( return (
<> <>
<MenuItem icon={<ExternalLinkIcon />} onClickCapture={handleOpenInNewTab}> <Link href={imageDTO.image_url} target="_blank">
{t('common.openInNewTab')} <MenuItem
</MenuItem> icon={<FaExternalLinkAlt />}
onClickCapture={handleOpenInNewTab}
>
{t('common.openInNewTab')}
</MenuItem>
</Link>
{isClipboardAPIAvailable && (
<MenuItem icon={<FaCopy />} onClickCapture={handleCopyImage}>
{t('parameters.copyImage')}
</MenuItem>
)}
<MenuItem <MenuItem
icon={<IoArrowUndoCircleOutline />} icon={<FaQuoteRight />}
onClickCapture={handleRecallPrompt} onClickCapture={handleRecallPrompt}
isDisabled={ isDisabled={
metadata?.positive_prompt === undefined && metadata?.positive_prompt === undefined &&
@ -147,14 +173,14 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
</MenuItem> </MenuItem>
<MenuItem <MenuItem
icon={<IoArrowUndoCircleOutline />} icon={<FaSeedling />}
onClickCapture={handleRecallSeed} onClickCapture={handleRecallSeed}
isDisabled={metadata?.seed === undefined} isDisabled={metadata?.seed === undefined}
> >
{t('parameters.useSeed')} {t('parameters.useSeed')}
</MenuItem> </MenuItem>
<MenuItem <MenuItem
icon={<IoArrowUndoCircleOutline />} icon={<FaAsterisk />}
onClickCapture={handleUseAllParameters} onClickCapture={handleUseAllParameters}
isDisabled={!metadata} isDisabled={!metadata}
> >
@ -193,6 +219,11 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
Remove from Board Remove from Board
</MenuItem> </MenuItem>
)} )}
<Link download={true} href={imageDTO.image_url} target="_blank">
<MenuItem icon={<FaDownload />} w="100%">
{t('parameters.downloadImage')}
</MenuItem>
</Link>
<MenuItem <MenuItem
sx={{ color: 'error.600', _dark: { color: 'error.300' } }} sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
icon={<FaTrash />} icon={<FaTrash />}

View File

@ -16,14 +16,13 @@ import {
ASSETS_CATEGORIES, ASSETS_CATEGORIES,
IMAGE_CATEGORIES, IMAGE_CATEGORIES,
IMAGE_LIMIT, IMAGE_LIMIT,
selectImagesAll,
} from 'features/gallery//store/gallerySlice'; } from 'features/gallery//store/gallerySlice';
import { selectFilteredImages } from 'features/gallery/store/gallerySelectors'; import { selectFilteredImages } from 'features/gallery/store/gallerySelectors';
import { VirtuosoGrid } from 'react-virtuoso'; import { VirtuosoGrid } from 'react-virtuoso';
import { receivedPageOfImages } from 'services/api/thunks/image'; import { receivedPageOfImages } from 'services/api/thunks/image';
import { useListBoardImagesQuery } from '../../../../services/api/endpoints/boardImages';
import ImageGridItemContainer from './ImageGridItemContainer'; import ImageGridItemContainer from './ImageGridItemContainer';
import ImageGridListContainer from './ImageGridListContainer'; import ImageGridListContainer from './ImageGridListContainer';
import { useListBoardImagesQuery } from '../../../../services/api/endpoints/boardImages';
const selector = createSelector( const selector = createSelector(
[stateSelector, selectFilteredImages], [stateSelector, selectFilteredImages],
@ -180,7 +179,6 @@ const GalleryImageGrid = () => {
</Box> </Box>
); );
} }
console.log({ selectedBoardId });
if (status !== 'rejected') { if (status !== 'rejected') {
return ( return (

View File

@ -110,8 +110,11 @@ const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
return ( return (
<div ref={ref} {...others}> <div ref={ref} {...others}>
<div> <div>
<Text>{label}</Text> <Text fontWeight={600}>{label}</Text>
<Text size="xs" color="base.600"> <Text
size="xs"
sx={{ color: 'base.600', _dark: { color: 'base.500' } }}
>
{description} {description}
</Text> </Text>
</div> </div>

View File

@ -20,8 +20,8 @@ const IAINodeHeader = (props: IAINodeHeaderProps) => {
justifyContent: 'space-between', justifyContent: 'space-between',
px: 2, px: 2,
py: 1, py: 1,
bg: 'base.300', bg: 'base.100',
_dark: { bg: 'base.700' }, _dark: { bg: 'base.900' },
}} }}
> >
<Tooltip label={nodeId}> <Tooltip label={nodeId}>
@ -30,7 +30,7 @@ const IAINodeHeader = (props: IAINodeHeaderProps) => {
sx={{ sx={{
fontWeight: 600, fontWeight: 600,
color: 'base.900', color: 'base.900',
_dark: { color: 'base.100' }, _dark: { color: 'base.200' },
}} }}
> >
{title} {title}

View File

@ -151,7 +151,6 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
nodeId={nodeId} nodeId={nodeId}
field={field} field={field}
template={template} template={template}
base_models={['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']}
/> />
); );
} }

View File

@ -59,7 +59,7 @@ export const InvocationComponent = memo((props: NodeProps<InvocationValue>) => {
flexDirection: 'column', flexDirection: 'column',
borderBottomRadius: 'md', borderBottomRadius: 'md',
py: 2, py: 2,
bg: 'base.200', bg: 'base.150',
_dark: { bg: 'base.800' }, _dark: { bg: 'base.800' },
}} }}
> >

View File

@ -1,9 +1,9 @@
import 'reactflow/dist/style.css';
import { Box } from '@chakra-ui/react'; import { Box } from '@chakra-ui/react';
import { ReactFlowProvider } from 'reactflow'; import { ReactFlowProvider } from 'reactflow';
import 'reactflow/dist/style.css';
import { Flow } from './Flow';
import { memo } from 'react'; import { memo } from 'react';
import { Flow } from './Flow';
const NodeEditor = () => { const NodeEditor = () => {
return ( return (

View File

@ -1,9 +1,9 @@
import { Box, useToken } from '@chakra-ui/react'; import { Box, useToken } from '@chakra-ui/react';
import { NODE_MIN_WIDTH } from 'app/constants'; import { NODE_MIN_WIDTH } from 'app/constants';
import { useAppSelector } from 'app/store/storeHooks';
import { PropsWithChildren } from 'react'; import { PropsWithChildren } from 'react';
import { DRAG_HANDLE_CLASSNAME } from '../hooks/useBuildInvocation'; import { DRAG_HANDLE_CLASSNAME } from '../hooks/useBuildInvocation';
import { useAppSelector } from 'app/store/storeHooks';
type NodeWrapperProps = PropsWithChildren & { type NodeWrapperProps = PropsWithChildren & {
selected: boolean; selected: boolean;

View File

@ -1,17 +1,36 @@
import { ButtonGroup } from '@chakra-ui/react'; import { ButtonGroup, Tooltip } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { FaCode, FaExpand, FaMinus, FaPlus } from 'react-icons/fa'; import {
FaCode,
FaExpand,
FaMinus,
FaPlus,
FaInfo,
FaMapMarkerAlt,
} from 'react-icons/fa';
import { useReactFlow } from 'reactflow'; import { useReactFlow } from 'reactflow';
import { shouldShowGraphOverlayChanged } from '../store/nodesSlice'; import { useTranslation } from 'react-i18next';
import {
shouldShowGraphOverlayChanged,
shouldShowFieldTypeLegendChanged,
shouldShowMinimapPanelChanged,
} from '../store/nodesSlice';
const ViewportControls = () => { const ViewportControls = () => {
const { t } = useTranslation();
const { zoomIn, zoomOut, fitView } = useReactFlow(); const { zoomIn, zoomOut, fitView } = useReactFlow();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const shouldShowGraphOverlay = useAppSelector( const shouldShowGraphOverlay = useAppSelector(
(state) => state.nodes.shouldShowGraphOverlay (state) => state.nodes.shouldShowGraphOverlay
); );
const shouldShowFieldTypeLegend = useAppSelector(
(state) => state.nodes.shouldShowFieldTypeLegend
);
const shouldShowMinimapPanel = useAppSelector(
(state) => state.nodes.shouldShowMinimapPanel
);
const handleClickedZoomIn = useCallback(() => { const handleClickedZoomIn = useCallback(() => {
zoomIn(); zoomIn();
@ -29,29 +48,64 @@ const ViewportControls = () => {
dispatch(shouldShowGraphOverlayChanged(!shouldShowGraphOverlay)); dispatch(shouldShowGraphOverlayChanged(!shouldShowGraphOverlay));
}, [shouldShowGraphOverlay, dispatch]); }, [shouldShowGraphOverlay, dispatch]);
const handleClickedToggleFieldTypeLegend = useCallback(() => {
dispatch(shouldShowFieldTypeLegendChanged(!shouldShowFieldTypeLegend));
}, [shouldShowFieldTypeLegend, dispatch]);
const handleClickedToggleMiniMapPanel = useCallback(() => {
dispatch(shouldShowMinimapPanelChanged(!shouldShowMinimapPanel));
}, [shouldShowMinimapPanel, dispatch]);
return ( return (
<ButtonGroup isAttached orientation="vertical"> <ButtonGroup isAttached orientation="vertical">
<IAIIconButton <Tooltip label={t('nodes.zoomInNodes')}>
onClick={handleClickedZoomIn} <IAIIconButton onClick={handleClickedZoomIn} icon={<FaPlus />} />
aria-label="Zoom In" </Tooltip>
icon={<FaPlus />} <Tooltip label={t('nodes.zoomOutNodes')}>
/> <IAIIconButton onClick={handleClickedZoomOut} icon={<FaMinus />} />
<IAIIconButton </Tooltip>
onClick={handleClickedZoomOut} <Tooltip label={t('nodes.fitViewportNodes')}>
aria-label="Zoom Out" <IAIIconButton onClick={handleClickedFitView} icon={<FaExpand />} />
icon={<FaMinus />} </Tooltip>
/> <Tooltip
<IAIIconButton label={
onClick={handleClickedFitView} shouldShowGraphOverlay
aria-label="Fit to Viewport" ? t('nodes.hideGraphNodes')
icon={<FaExpand />} : t('nodes.showGraphNodes')
/> }
<IAIIconButton >
isChecked={shouldShowGraphOverlay} <IAIIconButton
onClick={handleClickedToggleGraphOverlay} isChecked={shouldShowGraphOverlay}
aria-label="Show/Hide Graph" onClick={handleClickedToggleGraphOverlay}
icon={<FaCode />} icon={<FaCode />}
/> />
</Tooltip>
<Tooltip
label={
shouldShowFieldTypeLegend
? t('nodes.hideLegendNodes')
: t('nodes.showLegendNodes')
}
>
<IAIIconButton
isChecked={shouldShowFieldTypeLegend}
onClick={handleClickedToggleFieldTypeLegend}
icon={<FaInfo />}
/>
</Tooltip>
<Tooltip
label={
shouldShowMinimapPanel
? t('nodes.hideMinimapnodes')
: t('nodes.showMinimapnodes')
}
>
<IAIIconButton
isChecked={shouldShowMinimapPanel}
onClick={handleClickedToggleMiniMapPanel}
icon={<FaMapMarkerAlt />}
/>
</Tooltip>
</ButtonGroup> </ButtonGroup>
); );
}; };

View File

@ -1,3 +1,5 @@
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { useColorModeValue } from '@chakra-ui/react'; import { useColorModeValue } from '@chakra-ui/react';
import { memo } from 'react'; import { memo } from 'react';
import { MiniMap } from 'reactflow'; import { MiniMap } from 'reactflow';
@ -12,6 +14,10 @@ const MinimapPanel = () => {
} }
); );
const shouldShowMinimapPanel = useAppSelector(
(state: RootState) => state.nodes.shouldShowMinimapPanel
);
const nodeColor = useColorModeValue( const nodeColor = useColorModeValue(
'var(--invokeai-colors-accent-300)', 'var(--invokeai-colors-accent-300)',
'var(--invokeai-colors-accent-700)' 'var(--invokeai-colors-accent-700)'
@ -23,15 +29,19 @@ const MinimapPanel = () => {
); );
return ( return (
<MiniMap <>
nodeStrokeWidth={3} {shouldShowMinimapPanel && (
pannable <MiniMap
zoomable nodeStrokeWidth={3}
nodeBorderRadius={30} pannable
style={miniMapStyle} zoomable
nodeColor={nodeColor} nodeBorderRadius={30}
maskColor={maskColor} style={miniMapStyle}
/> nodeColor={nodeColor}
maskColor={maskColor}
/>
)}
</>
); );
}; };

View File

@ -9,10 +9,13 @@ const TopRightPanel = () => {
const shouldShowGraphOverlay = useAppSelector( const shouldShowGraphOverlay = useAppSelector(
(state: RootState) => state.nodes.shouldShowGraphOverlay (state: RootState) => state.nodes.shouldShowGraphOverlay
); );
const shouldShowFieldTypeLegend = useAppSelector(
(state: RootState) => state.nodes.shouldShowFieldTypeLegend
);
return ( return (
<Panel position="top-right"> <Panel position="top-right">
<FieldTypeLegend /> {shouldShowFieldTypeLegend && <FieldTypeLegend />}
{shouldShowGraphOverlay && <NodeGraphOverlay />} {shouldShowGraphOverlay && <NodeGraphOverlay />}
</Panel> </Panel>
); );

View File

@ -32,6 +32,8 @@ export type NodesState = {
invocationTemplates: Record<string, InvocationTemplate>; invocationTemplates: Record<string, InvocationTemplate>;
connectionStartParams: OnConnectStartParams | null; connectionStartParams: OnConnectStartParams | null;
shouldShowGraphOverlay: boolean; shouldShowGraphOverlay: boolean;
shouldShowFieldTypeLegend: boolean;
shouldShowMinimapPanel: boolean;
editorInstance: ReactFlowInstance | undefined; editorInstance: ReactFlowInstance | undefined;
}; };
@ -42,6 +44,8 @@ export const initialNodesState: NodesState = {
invocationTemplates: {}, invocationTemplates: {},
connectionStartParams: null, connectionStartParams: null,
shouldShowGraphOverlay: false, shouldShowGraphOverlay: false,
shouldShowFieldTypeLegend: false,
shouldShowMinimapPanel: true,
editorInstance: undefined, editorInstance: undefined,
}; };
@ -125,6 +129,15 @@ const nodesSlice = createSlice({
shouldShowGraphOverlayChanged: (state, action: PayloadAction<boolean>) => { shouldShowGraphOverlayChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowGraphOverlay = action.payload; state.shouldShowGraphOverlay = action.payload;
}, },
shouldShowFieldTypeLegendChanged: (
state,
action: PayloadAction<boolean>
) => {
state.shouldShowFieldTypeLegend = action.payload;
},
shouldShowMinimapPanelChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowMinimapPanel = action.payload;
},
nodeTemplatesBuilt: ( nodeTemplatesBuilt: (
state, state,
action: PayloadAction<Record<string, InvocationTemplate>> action: PayloadAction<Record<string, InvocationTemplate>>
@ -161,6 +174,8 @@ export const {
connectionStarted, connectionStarted,
connectionEnded, connectionEnded,
shouldShowGraphOverlayChanged, shouldShowGraphOverlayChanged,
shouldShowFieldTypeLegendChanged,
shouldShowMinimapPanelChanged,
nodeTemplatesBuilt, nodeTemplatesBuilt,
nodeEditorReset, nodeEditorReset,
imageCollectionFieldValueChanged, imageCollectionFieldValueChanged,

View File

@ -36,7 +36,7 @@ const ParamMainModelSelect = () => {
const data: SelectItem[] = []; const data: SelectItem[] = [];
forEach(mainModels.entities, (model, id) => { forEach(mainModels.entities, (model, id) => {
if (!model) { if (!model || ['sdxl', 'sdxl-refiner'].includes(model.base_model)) {
return; return;
} }

View File

@ -15,7 +15,7 @@ import {
ModalOverlay, ModalOverlay,
useDisclosure, useDisclosure,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { cloneElement, ReactElement } from 'react'; import { ReactElement, cloneElement } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import HotkeysModalItem from './HotkeysModalItem'; import HotkeysModalItem from './HotkeysModalItem';
@ -65,11 +65,6 @@ export default function HotkeysModal({ children }: HotkeysModalProps) {
desc: t('hotkeys.pinOptions.desc'), desc: t('hotkeys.pinOptions.desc'),
hotkey: 'Shift+O', hotkey: 'Shift+O',
}, },
{
title: t('hotkeys.toggleViewer.title'),
desc: t('hotkeys.toggleViewer.desc'),
hotkey: 'Z',
},
{ {
title: t('hotkeys.toggleGallery.title'), title: t('hotkeys.toggleGallery.title'),
desc: t('hotkeys.toggleGallery.desc'), desc: t('hotkeys.toggleGallery.desc'),
@ -85,12 +80,6 @@ export default function HotkeysModal({ children }: HotkeysModalProps) {
desc: t('hotkeys.changeTabs.desc'), desc: t('hotkeys.changeTabs.desc'),
hotkey: '1-5', hotkey: '1-5',
}, },
{
title: t('hotkeys.consoleToggle.title'),
desc: t('hotkeys.consoleToggle.desc'),
hotkey: '`',
},
]; ];
const generalHotkeys = [ const generalHotkeys = [
@ -109,11 +98,6 @@ export default function HotkeysModal({ children }: HotkeysModalProps) {
desc: t('hotkeys.setParameters.desc'), desc: t('hotkeys.setParameters.desc'),
hotkey: 'A', hotkey: 'A',
}, },
{
title: t('hotkeys.restoreFaces.title'),
desc: t('hotkeys.restoreFaces.desc'),
hotkey: 'Shift+R',
},
{ {
title: t('hotkeys.upscale.title'), title: t('hotkeys.upscale.title'),
desc: t('hotkeys.upscale.desc'), desc: t('hotkeys.upscale.desc'),

View File

@ -183,7 +183,7 @@ const SettingsModal = ({ children, config }: SettingsModalProps) => {
> >
<ModalOverlay /> <ModalOverlay />
<ModalContent> <ModalContent>
<ModalHeader>{t('common.settingsLabel')}</ModalHeader> <ModalHeader bg="none">{t('common.settingsLabel')}</ModalHeader>
<ModalCloseButton /> <ModalCloseButton />
<ModalBody> <ModalBody>
<Flex sx={{ gap: 4, flexDirection: 'column' }}> <Flex sx={{ gap: 4, flexDirection: 'column' }}>
@ -331,12 +331,15 @@ export default SettingsModal;
const StyledFlex = (props: PropsWithChildren) => { const StyledFlex = (props: PropsWithChildren) => {
return ( return (
<Flex <Flex
layerStyle="second"
sx={{ sx={{
flexDirection: 'column', flexDirection: 'column',
gap: 2, gap: 2,
p: 4, p: 4,
borderRadius: 'base', borderRadius: 'base',
bg: 'base.100',
_dark: {
bg: 'base.900',
},
}} }}
> >
{props.children} {props.children}

View File

@ -1,11 +1,10 @@
import { UseToastOptions } from '@chakra-ui/react'; import { UseToastOptions } from '@chakra-ui/react';
import { PayloadAction, createSlice } from '@reduxjs/toolkit'; import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/types/invokeai';
import { InvokeLogLevel } from 'app/logging/useLogger'; import { InvokeLogLevel } from 'app/logging/useLogger';
import { userInvoked } from 'app/store/actions'; import { userInvoked } from 'app/store/actions';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice'; import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { TFuncKey, t } from 'i18next'; import { t } from 'i18next';
import { LogLevelName } from 'roarr'; import { LogLevelName } from 'roarr';
import { imageUploaded } from 'services/api/thunks/image'; import { imageUploaded } from 'services/api/thunks/image';
import { import {
@ -44,8 +43,6 @@ export interface SystemState {
isCancelable: boolean; isCancelable: boolean;
enableImageDebugging: boolean; enableImageDebugging: boolean;
toastQueue: UseToastOptions[]; toastQueue: UseToastOptions[];
searchFolder: string | null;
foundModels: InvokeAI.FoundModel[] | null;
/** /**
* The current progress image * The current progress image
*/ */
@ -79,7 +76,7 @@ export interface SystemState {
*/ */
consoleLogLevel: InvokeLogLevel; consoleLogLevel: InvokeLogLevel;
shouldLogToConsole: boolean; shouldLogToConsole: boolean;
statusTranslationKey: TFuncKey; statusTranslationKey: any;
/** /**
* When a session is canceled, its ID is stored here until a new session is created. * When a session is canceled, its ID is stored here until a new session is created.
*/ */
@ -106,8 +103,6 @@ export const initialSystemState: SystemState = {
isCancelable: true, isCancelable: true,
enableImageDebugging: false, enableImageDebugging: false,
toastQueue: [], toastQueue: [],
searchFolder: null,
foundModels: null,
progressImage: null, progressImage: null,
shouldAntialiasProgressImage: false, shouldAntialiasProgressImage: false,
sessionId: null, sessionId: null,
@ -132,7 +127,7 @@ export const systemSlice = createSlice({
setIsProcessing: (state, action: PayloadAction<boolean>) => { setIsProcessing: (state, action: PayloadAction<boolean>) => {
state.isProcessing = action.payload; state.isProcessing = action.payload;
}, },
setCurrentStatus: (state, action: PayloadAction<TFuncKey>) => { setCurrentStatus: (state, action: any) => {
state.statusTranslationKey = action.payload; state.statusTranslationKey = action.payload;
}, },
setShouldConfirmOnDelete: (state, action: PayloadAction<boolean>) => { setShouldConfirmOnDelete: (state, action: PayloadAction<boolean>) => {
@ -153,15 +148,6 @@ export const systemSlice = createSlice({
clearToastQueue: (state) => { clearToastQueue: (state) => {
state.toastQueue = []; state.toastQueue = [];
}, },
setSearchFolder: (state, action: PayloadAction<string | null>) => {
state.searchFolder = action.payload;
},
setFoundModels: (
state,
action: PayloadAction<InvokeAI.FoundModel[] | null>
) => {
state.foundModels = action.payload;
},
/** /**
* A cancel was scheduled * A cancel was scheduled
*/ */
@ -426,8 +412,6 @@ export const {
setEnableImageDebugging, setEnableImageDebugging,
addToast, addToast,
clearToastQueue, clearToastQueue,
setSearchFolder,
setFoundModels,
cancelScheduled, cancelScheduled,
scheduledCancelAborted, scheduledCancelAborted,
cancelTypeChanged, cancelTypeChanged,

View File

@ -1,11 +1,11 @@
import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@chakra-ui/react'; import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@chakra-ui/react';
import i18n from 'i18n'; import i18n from 'i18n';
import { ReactNode, memo } from 'react'; import { ReactNode, memo } from 'react';
import AddModelsPanel from './subpanels/AddModelsPanel'; import ImportModelsPanel from './subpanels/ImportModelsPanel';
import MergeModelsPanel from './subpanels/MergeModelsPanel'; import MergeModelsPanel from './subpanels/MergeModelsPanel';
import ModelManagerPanel from './subpanels/ModelManagerPanel'; import ModelManagerPanel from './subpanels/ModelManagerPanel';
type ModelManagerTabName = 'modelManager' | 'addModels' | 'mergeModels'; type ModelManagerTabName = 'modelManager' | 'importModels' | 'mergeModels';
type ModelManagerTabInfo = { type ModelManagerTabInfo = {
id: ModelManagerTabName; id: ModelManagerTabName;
@ -20,9 +20,9 @@ const tabs: ModelManagerTabInfo[] = [
content: <ModelManagerPanel />, content: <ModelManagerPanel />,
}, },
{ {
id: 'addModels', id: 'importModels',
label: i18n.t('modelManager.addModel'), label: i18n.t('modelManager.importModels'),
content: <AddModelsPanel />, content: <ImportModelsPanel />,
}, },
{ {
id: 'mergeModels', id: 'mergeModels',
@ -46,7 +46,7 @@ const ModelManagerTab = () => {
</Tab> </Tab>
))} ))}
</TabList> </TabList>
<TabPanels sx={{ w: 'full', h: 'full', p: 4 }}> <TabPanels sx={{ w: 'full', h: 'full' }}>
{tabs.map((tab) => ( {tabs.map((tab) => (
<TabPanel sx={{ w: 'full', h: 'full' }} key={tab.id}> <TabPanel sx={{ w: 'full', h: 'full' }} key={tab.id}>
{tab.content} {tab.content}

View File

@ -0,0 +1,29 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
type ModelManagerState = {
searchFolder: string | null;
advancedAddScanModel: string | null;
};
const initialModelManagerState: ModelManagerState = {
searchFolder: null,
advancedAddScanModel: null,
};
export const modelManagerSlice = createSlice({
name: 'modelmanager',
initialState: initialModelManagerState,
reducers: {
setSearchFolder: (state, action: PayloadAction<string | null>) => {
state.searchFolder = action.payload;
},
setAdvancedAddScanModel: (state, action: PayloadAction<string | null>) => {
state.advancedAddScanModel = action.payload;
},
},
});
export const { setSearchFolder, setAdvancedAddScanModel } =
modelManagerSlice.actions;
export default modelManagerSlice.reducer;

View File

@ -0,0 +1,3 @@
import { RootState } from 'app/store/store';
export const modelmanagerSelector = (state: RootState) => state.modelmanager;

View File

@ -1,43 +0,0 @@
import { Divider, Flex, useColorMode } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import { useTranslation } from 'react-i18next';
import AddCheckpointModel from './AddModelsPanel/AddCheckpointModel';
import AddDiffusersModel from './AddModelsPanel/AddDiffusersModel';
export default function AddModelsPanel() {
const addNewModelUIOption = useAppSelector(
(state: RootState) => state.ui.addNewModelUIOption
);
const { colorMode } = useColorMode();
const dispatch = useAppDispatch();
const { t } = useTranslation();
return (
<Flex flexDirection="column" gap={4}>
<Flex columnGap={4}>
<IAIButton
onClick={() => dispatch(setAddNewModelUIOption('ckpt'))}
isChecked={addNewModelUIOption == 'ckpt'}
>
{t('modelManager.addCheckpointModel')}
</IAIButton>
<IAIButton
onClick={() => dispatch(setAddNewModelUIOption('diffusers'))}
isChecked={addNewModelUIOption == 'diffusers'}
>
{t('modelManager.addDiffuserModel')}
</IAIButton>
</Flex>
<Divider />
{addNewModelUIOption == 'ckpt' && <AddCheckpointModel />}
{addNewModelUIOption == 'diffusers' && <AddDiffusersModel />}
</Flex>
);
}

View File

@ -1,337 +0,0 @@
import {
Flex,
FormControl,
FormErrorMessage,
FormHelperText,
FormLabel,
HStack,
Text,
VStack,
} from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import React from 'react';
// import { addNewModel } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next';
import type { RootState } from 'app/store/store';
import type { InvokeModelConfigProps } from 'app/types/invokeai';
import IAIForm from 'common/components/IAIForm';
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import type { FieldInputProps, FormikProps } from 'formik';
import SearchModels from './SearchModels';
const MIN_MODEL_SIZE = 64;
const MAX_MODEL_SIZE = 2048;
export default function AddCheckpointModel() {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
function hasWhiteSpace(s: string) {
return /\s/.test(s);
}
function baseValidation(value: string) {
let error;
if (hasWhiteSpace(value)) error = t('modelManager.cannotUseSpaces');
return error;
}
const addModelFormValues: InvokeModelConfigProps = {
name: '',
description: '',
config: 'configs/stable-diffusion/v1-inference.yaml',
weights: '',
vae: '',
width: 512,
height: 512,
format: 'ckpt',
default: false,
};
const addModelFormSubmitHandler = (values: InvokeModelConfigProps) => {
dispatch(addNewModel(values));
dispatch(setAddNewModelUIOption(null));
};
const [addManually, setAddmanually] = React.useState<boolean>(false);
return (
<VStack gap={2} alignItems="flex-start">
<Flex columnGap={4}>
<IAISimpleCheckbox
isChecked={!addManually}
label={t('modelManager.scanForModels')}
onChange={() => setAddmanually(!addManually)}
/>
<IAISimpleCheckbox
label={t('modelManager.addManually')}
isChecked={addManually}
onChange={() => setAddmanually(!addManually)}
/>
</Flex>
{addManually ? (
<Formik
initialValues={addModelFormValues}
onSubmit={addModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<IAIForm onSubmit={handleSubmit} sx={{ w: 'full' }}>
<VStack rowGap={2}>
<Text fontSize={20} fontWeight="bold" alignSelf="start">
{t('modelManager.manual')}
</Text>
{/* Name */}
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.name && touched.name}
isRequired
>
<FormLabel htmlFor="name" fontSize="sm">
{t('modelManager.name')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="name"
name="name"
type="text"
validate={baseValidation}
width="full"
/>
{!!errors.name && touched.name ? (
<FormErrorMessage>{errors.name}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.nameValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Description */}
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.description && touched.description}
isRequired
>
<FormLabel htmlFor="description" fontSize="sm">
{t('modelManager.description')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="description"
name="description"
type="text"
width="full"
/>
{!!errors.description && touched.description ? (
<FormErrorMessage>
{errors.description}
</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.descriptionValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Config */}
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.config && touched.config}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.config')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="config"
name="config"
type="text"
width="full"
/>
{!!errors.config && touched.config ? (
<FormErrorMessage>{errors.config}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.configValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Weights */}
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.weights && touched.weights}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="weights"
name="weights"
type="text"
width="full"
/>
{!!errors.weights && touched.weights ? (
<FormErrorMessage>{errors.weights}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.modelLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* VAE */}
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.vae && touched.vae}>
<FormLabel htmlFor="vae" fontSize="sm">
{t('modelManager.vaeLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae"
name="vae"
type="text"
width="full"
/>
{!!errors.vae && touched.vae ? (
<FormErrorMessage>{errors.vae}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.vaeLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<HStack width="100%">
{/* Width */}
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.width && touched.width}>
<FormLabel htmlFor="width" fontSize="sm">
{t('modelManager.width')}
</FormLabel>
<VStack alignItems="start">
<Field id="width" name="width">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="width"
name="width"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
value={form.values.width}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
{!!errors.width && touched.width ? (
<FormErrorMessage>{errors.width}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.widthValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Height */}
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.height && touched.height}>
<FormLabel htmlFor="height" fontSize="sm">
{t('modelManager.height')}
</FormLabel>
<VStack alignItems="start">
<Field id="height" name="height">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="height"
name="height"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
value={form.values.height}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
{!!errors.height && touched.height ? (
<FormErrorMessage>{errors.height}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.heightValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
</HStack>
<IAIButton
type="submit"
className="modal-close-btn"
isLoading={isProcessing}
>
{t('modelManager.addModel')}
</IAIButton>
</VStack>
</IAIForm>
)}
</Formik>
) : (
<SearchModels />
)}
</VStack>
);
}

View File

@ -1,259 +0,0 @@
import {
Flex,
FormControl,
FormErrorMessage,
FormHelperText,
FormLabel,
Text,
VStack,
} from '@chakra-ui/react';
import { InvokeDiffusersModelConfigProps } from 'app/types/invokeai';
// import { addNewModel } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next';
import type { RootState } from 'app/store/store';
import IAIForm from 'common/components/IAIForm';
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
export default function AddDiffusersModel() {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
function hasWhiteSpace(s: string) {
return /\s/.test(s);
}
function baseValidation(value: string) {
let error;
if (hasWhiteSpace(value)) error = t('modelManager.cannotUseSpaces');
return error;
}
const addModelFormValues: InvokeDiffusersModelConfigProps = {
name: '',
description: '',
repo_id: '',
path: '',
format: 'diffusers',
default: false,
vae: {
repo_id: '',
path: '',
},
};
const addModelFormSubmitHandler = (
values: InvokeDiffusersModelConfigProps
) => {
const diffusersModelToAdd = values;
if (values.path === '') delete diffusersModelToAdd.path;
if (values.repo_id === '') delete diffusersModelToAdd.repo_id;
if (values.vae.path === '') delete diffusersModelToAdd.vae.path;
if (values.vae.repo_id === '') delete diffusersModelToAdd.vae.repo_id;
dispatch(addNewModel(diffusersModelToAdd));
dispatch(setAddNewModelUIOption(null));
};
return (
<Flex overflow="scroll" maxHeight={window.innerHeight - 270} width="100%">
<Formik
initialValues={addModelFormValues}
onSubmit={addModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<IAIForm onSubmit={handleSubmit} w="full">
<VStack rowGap={2}>
<IAIFormItemWrapper>
{/* Name */}
<FormControl
isInvalid={!!errors.name && touched.name}
isRequired
>
<FormLabel htmlFor="name" fontSize="sm">
{t('modelManager.name')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="name"
name="name"
type="text"
validate={baseValidation}
isRequired
/>
{!!errors.name && touched.name ? (
<FormErrorMessage>{errors.name}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.nameValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<IAIFormItemWrapper>
{/* Description */}
<FormControl
isInvalid={!!errors.description && touched.description}
isRequired
>
<FormLabel htmlFor="description" fontSize="sm">
{t('modelManager.description')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="description"
name="description"
type="text"
isRequired
/>
{!!errors.description && touched.description ? (
<FormErrorMessage>{errors.description}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.descriptionValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<IAIFormItemWrapper>
<Text fontWeight="bold" fontSize="sm">
{t('modelManager.formMessageDiffusersModelLocation')}
</Text>
<Text
sx={{
fontSize: 'sm',
fontStyle: 'italic',
}}
variant="subtext"
>
{t('modelManager.formMessageDiffusersModelLocationDesc')}
</Text>
{/* Path */}
<FormControl isInvalid={!!errors.path && touched.path}>
<FormLabel htmlFor="path" fontSize="sm">
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field as={IAIInput} id="path" name="path" type="text" />
{!!errors.path && touched.path ? (
<FormErrorMessage>{errors.path}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.modelLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
{/* Repo ID */}
<FormControl isInvalid={!!errors.repo_id && touched.repo_id}>
<FormLabel htmlFor="repo_id" fontSize="sm">
{t('modelManager.repo_id')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="repo_id"
name="repo_id"
type="text"
/>
{!!errors.repo_id && touched.repo_id ? (
<FormErrorMessage>{errors.repo_id}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.repoIDValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<IAIFormItemWrapper>
{/* VAE Path */}
<Text fontWeight="bold">
{t('modelManager.formMessageDiffusersVAELocation')}
</Text>
<Text
sx={{
fontSize: 'sm',
fontStyle: 'italic',
}}
variant="subtext"
>
{t('modelManager.formMessageDiffusersVAELocationDesc')}
</Text>
<FormControl
isInvalid={!!errors.vae?.path && touched.vae?.path}
>
<FormLabel htmlFor="vae.path" fontSize="sm">
{t('modelManager.vaeLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae.path"
name="vae.path"
type="text"
/>
{!!errors.vae?.path && touched.vae?.path ? (
<FormErrorMessage>{errors.vae?.path}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.vaeLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
{/* VAE Repo ID */}
<FormControl
isInvalid={!!errors.vae?.repo_id && touched.vae?.repo_id}
>
<FormLabel htmlFor="vae.repo_id" fontSize="sm">
{t('modelManager.vaeRepoID')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae.repo_id"
name="vae.repo_id"
type="text"
/>
{!!errors.vae?.repo_id && touched.vae?.repo_id ? (
<FormErrorMessage>{errors.vae?.repo_id}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.vaeRepoIDValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<IAIButton type="submit" isLoading={isProcessing}>
{t('modelManager.addModel')}
</IAIButton>
</VStack>
</IAIForm>
)}
</Formik>
</Flex>
);
}

View File

@ -0,0 +1,48 @@
import { ButtonGroup, Flex } from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import { useState } from 'react';
import AdvancedAddModels from './AdvancedAddModels';
import SimpleAddModels from './SimpleAddModels';
export default function AddModels() {
const [addModelMode, setAddModelMode] = useState<'simple' | 'advanced'>(
'simple'
);
return (
<Flex
flexDirection="column"
width="100%"
overflow="scroll"
maxHeight={window.innerHeight - 250}
gap={4}
>
<ButtonGroup isAttached>
<IAIButton
size="sm"
isChecked={addModelMode == 'simple'}
onClick={() => setAddModelMode('simple')}
>
Simple
</IAIButton>
<IAIButton
size="sm"
isChecked={addModelMode == 'advanced'}
onClick={() => setAddModelMode('advanced')}
>
Advanced
</IAIButton>
</ButtonGroup>
<Flex
sx={{
p: 4,
borderRadius: 4,
background: 'base.200',
_dark: { background: 'base.800' },
}}
>
{addModelMode === 'simple' && <SimpleAddModels />}
{addModelMode === 'advanced' && <AdvancedAddModels />}
</Flex>
</Flex>
);
}

View File

@ -0,0 +1,143 @@
import { Flex } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import { addToast } from 'features/system/store/systemSlice';
import { useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
import { CheckpointModelConfig } from 'services/api/types';
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
import BaseModelSelect from '../shared/BaseModelSelect';
import CheckpointConfigsSelect from '../shared/CheckpointConfigsSelect';
import ModelVariantSelect from '../shared/ModelVariantSelect';
type AdvancedAddCheckpointProps = {
model_path?: string;
};
export default function AdvancedAddCheckpoint(
props: AdvancedAddCheckpointProps
) {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { model_path } = props;
const advancedAddCheckpointForm = useForm<CheckpointModelConfig>({
initialValues: {
model_name: model_path
? model_path.split('\\').splice(-1)[0].split('.')[0]
: '',
base_model: 'sd-1',
model_type: 'main',
path: model_path ? model_path : '',
description: '',
model_format: 'checkpoint',
error: undefined,
vae: '',
variant: 'normal',
config: 'configs\\stable-diffusion\\v1-inference.yaml',
},
});
const [addMainModel] = useAddMainModelsMutation();
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
const advancedAddCheckpointFormHandler = (values: CheckpointModelConfig) => {
addMainModel({
body: values,
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: `Model Added: ${values.model_name}`,
status: 'success',
})
)
);
advancedAddCheckpointForm.reset();
// Close Advanced Panel in Scan Models tab
if (model_path) {
dispatch(setAdvancedAddScanModel(null));
}
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: 'Model Add Failed',
status: 'error',
})
)
);
}
});
};
return (
<form
onSubmit={advancedAddCheckpointForm.onSubmit((v) =>
advancedAddCheckpointFormHandler(v)
)}
style={{ width: '100%' }}
>
<Flex flexDirection="column" gap={2}>
<IAIMantineTextInput
label="Model Name"
required
{...advancedAddCheckpointForm.getInputProps('model_name')}
/>
<BaseModelSelect
{...advancedAddCheckpointForm.getInputProps('base_model')}
/>
<IAIMantineTextInput
label="Model Location"
required
{...advancedAddCheckpointForm.getInputProps('path')}
/>
<IAIMantineTextInput
label="Description"
{...advancedAddCheckpointForm.getInputProps('description')}
/>
<IAIMantineTextInput
label="VAE Location"
{...advancedAddCheckpointForm.getInputProps('vae')}
/>
<ModelVariantSelect
{...advancedAddCheckpointForm.getInputProps('variant')}
/>
<Flex flexDirection="column" width="100%" gap={2}>
{!useCustomConfig ? (
<CheckpointConfigsSelect
required
width="100%"
{...advancedAddCheckpointForm.getInputProps('config')}
/>
) : (
<IAIMantineTextInput
required
label="Custom Config File Location"
{...advancedAddCheckpointForm.getInputProps('config')}
/>
)}
<IAISimpleCheckbox
isChecked={useCustomConfig}
onChange={() => setUseCustomConfig(!useCustomConfig)}
label="Use Custom Config"
/>
<IAIButton mt={2} type="submit">
{t('modelManager.addModel')}
</IAIButton>
</Flex>
</Flex>
</form>
);
}

View File

@ -0,0 +1,113 @@
import { Flex } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import { addToast } from 'features/system/store/systemSlice';
import { useTranslation } from 'react-i18next';
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
import { DiffusersModelConfig } from 'services/api/types';
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
import BaseModelSelect from '../shared/BaseModelSelect';
import ModelVariantSelect from '../shared/ModelVariantSelect';
type AdvancedAddDiffusersProps = {
model_path?: string;
};
export default function AdvancedAddDiffusers(props: AdvancedAddDiffusersProps) {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { model_path } = props;
const [addMainModel] = useAddMainModelsMutation();
const advancedAddDiffusersForm = useForm<DiffusersModelConfig>({
initialValues: {
model_name: model_path ? model_path.split('\\').splice(-1)[0] : '',
base_model: 'sd-1',
model_type: 'main',
path: model_path ? model_path : '',
description: '',
model_format: 'diffusers',
error: undefined,
vae: '',
variant: 'normal',
},
});
const advancedAddDiffusersFormHandler = (values: DiffusersModelConfig) => {
addMainModel({
body: values,
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: `Model Added: ${values.model_name}`,
status: 'success',
})
)
);
advancedAddDiffusersForm.reset();
// Close Advanced Panel in Scan Models tab
if (model_path) {
dispatch(setAdvancedAddScanModel(null));
}
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: 'Model Add Failed',
status: 'error',
})
)
);
}
});
};
return (
<form
onSubmit={advancedAddDiffusersForm.onSubmit((v) =>
advancedAddDiffusersFormHandler(v)
)}
style={{ width: '100%' }}
>
<Flex flexDirection="column" gap={2}>
<IAIMantineTextInput
required
label="Model Name"
{...advancedAddDiffusersForm.getInputProps('model_name')}
/>
<BaseModelSelect
{...advancedAddDiffusersForm.getInputProps('base_model')}
/>
<IAIMantineTextInput
required
label="Model Location"
placeholder="Provide the path to a local folder where your Diffusers Model is stored"
{...advancedAddDiffusersForm.getInputProps('path')}
/>
<IAIMantineTextInput
label="Description"
{...advancedAddDiffusersForm.getInputProps('description')}
/>
<IAIMantineTextInput
label="VAE Location"
{...advancedAddDiffusersForm.getInputProps('vae')}
/>
<ModelVariantSelect
{...advancedAddDiffusersForm.getInputProps('variant')}
/>
<IAIButton mt={2} type="submit">
{t('modelManager.addModel')}
</IAIButton>
</Flex>
</form>
);
}

View File

@ -0,0 +1,46 @@
import { Flex } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { useState } from 'react';
import AdvancedAddCheckpoint from './AdvancedAddCheckpoint';
import AdvancedAddDiffusers from './AdvancedAddDiffusers';
export const advancedAddModeData: SelectItem[] = [
{ label: 'Diffusers', value: 'diffusers' },
{ label: 'Checkpoint / Safetensors', value: 'checkpoint' },
];
export type ManualAddMode = 'diffusers' | 'checkpoint';
export default function AdvancedAddModels() {
const [advancedAddMode, setAdvancedAddMode] =
useState<ManualAddMode>('diffusers');
return (
<Flex flexDirection="column" gap={4} width="100%">
<IAIMantineSelect
label="Model Type"
value={advancedAddMode}
data={advancedAddModeData}
onChange={(v) => {
if (!v) return;
setAdvancedAddMode(v as ManualAddMode);
}}
/>
<Flex
sx={{
p: 4,
borderRadius: 4,
bg: 'base.300',
_dark: {
bg: 'base.850',
},
}}
>
{advancedAddMode === 'diffusers' && <AdvancedAddDiffusers />}
{advancedAddMode === 'checkpoint' && <AdvancedAddCheckpoint />}
</Flex>
</Flex>
);
}

View File

@ -0,0 +1,253 @@
import { Flex, Text } from '@chakra-ui/react';
import { makeToast } from 'app/components/Toaster';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAIScrollArea from 'common/components/IAIScrollArea';
import { addToast } from 'features/system/store/systemSlice';
import { difference, forEach, intersection, map, values } from 'lodash-es';
import { ChangeEvent, MouseEvent, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import {
SearchFolderResponse,
useGetMainModelsQuery,
useGetModelsInFolderQuery,
useImportMainModelsMutation,
} from 'services/api/endpoints/models';
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
export default function FoundModelsList() {
const searchFolder = useAppSelector(
(state: RootState) => state.modelmanager.searchFolder
);
const [nameFilter, setNameFilter] = useState<string>('');
// Get paths of models that are already installed
const { data: installedModels } = useGetMainModelsQuery();
// Get all model paths from a given directory
const { foundModels, alreadyInstalled, filteredModels } =
useGetModelsInFolderQuery(
{
search_path: searchFolder ? searchFolder : '',
},
{
selectFromResult: ({ data }) => {
const installedModelValues = values(installedModels?.entities);
const installedModelPaths = map(installedModelValues, 'path');
// Only select models those that aren't already installed to Invoke
const notInstalledModels = difference(data, installedModelPaths);
const alreadyInstalled = intersection(data, installedModelPaths);
return {
foundModels: data,
alreadyInstalled: foundModelsFilter(alreadyInstalled, nameFilter),
filteredModels: foundModelsFilter(notInstalledModels, nameFilter),
};
},
}
);
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const quickAddHandler = useCallback(
(e: MouseEvent<HTMLButtonElement>) => {
const model_name = e.currentTarget.id.split('\\').splice(-1)[0];
importMainModel({
body: {
location: e.currentTarget.id,
},
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: `Added Model: ${model_name}`,
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: 'Faile To Add Model',
status: 'error',
})
)
);
}
});
},
[dispatch, importMainModel]
);
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setNameFilter(e.target.value);
}, []);
const renderModels = ({
models,
showActions = true,
}: {
models: string[];
showActions?: boolean;
}) => {
return models.map((model) => {
return (
<Flex
sx={{
p: 4,
gap: 4,
alignItems: 'center',
borderRadius: 4,
bg: 'base.200',
_dark: {
bg: 'base.800',
},
}}
key={model}
>
<Flex w="100%" sx={{ flexDirection: 'column', minW: '25%' }}>
<Text sx={{ fontWeight: 600 }}>
{model.split('\\').slice(-1)[0]}
</Text>
<Text
sx={{
fontSize: 'sm',
color: 'base.600',
_dark: {
color: 'base.400',
},
}}
>
{model}
</Text>
</Flex>
{showActions ? (
<Flex gap={2}>
<IAIButton
id={model}
onClick={quickAddHandler}
isLoading={isLoading}
>
Quick Add
</IAIButton>
<IAIButton
onClick={() => dispatch(setAdvancedAddScanModel(model))}
isLoading={isLoading}
>
Advanced
</IAIButton>
</Flex>
) : (
<Text
sx={{
fontWeight: 600,
p: 2,
borderRadius: 4,
color: 'accent.50',
bg: 'accent.400',
_dark: {
color: 'accent.100',
bg: 'accent.600',
},
}}
>
Installed
</Text>
)}
</Flex>
);
});
};
const renderFoundModels = () => {
if (!searchFolder) return;
if (!foundModels || foundModels.length === 0) {
return (
<Flex
sx={{
w: 'full',
h: 'full',
justifyContent: 'center',
alignItems: 'center',
height: 96,
userSelect: 'none',
bg: 'base.200',
_dark: {
bg: 'base.900',
},
}}
>
<Text variant="subtext">No Models Found</Text>
</Flex>
);
}
return (
<Flex
sx={{
flexDirection: 'column',
gap: 2,
w: '100%',
minW: '50%',
}}
>
<IAIInput
onChange={handleSearchFilter}
label={t('modelManager.search')}
labelPos="side"
/>
<Flex p={2} gap={2}>
<Text sx={{ fontWeight: 600 }}>
Models Found: {foundModels.length}
</Text>
<Text
sx={{
fontWeight: 600,
color: 'accent.500',
_dark: {
color: 'accent.200',
},
}}
>
Not Installed: {filteredModels.length}
</Text>
</Flex>
<IAIScrollArea offsetScrollbars>
<Flex gap={2} flexDirection="column">
{renderModels({ models: filteredModels })}
{renderModels({ models: alreadyInstalled, showActions: false })}
</Flex>
</IAIScrollArea>
</Flex>
);
};
return renderFoundModels();
}
const foundModelsFilter = (
data: SearchFolderResponse | undefined,
nameFilter: string
) => {
const filteredModels: SearchFolderResponse = [];
forEach(data, (model) => {
if (!model) {
return;
}
if (model.includes(nameFilter)) {
filteredModels.push(model);
}
});
return filteredModels;
};

View File

@ -0,0 +1,97 @@
import { Box, Flex, Text } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { motion } from 'framer-motion';
import { useEffect, useState } from 'react';
import { FaTimes } from 'react-icons/fa';
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
import AdvancedAddCheckpoint from './AdvancedAddCheckpoint';
import AdvancedAddDiffusers from './AdvancedAddDiffusers';
import { ManualAddMode, advancedAddModeData } from './AdvancedAddModels';
export default function ScanAdvancedAddModels() {
const advancedAddScanModel = useAppSelector(
(state: RootState) => state.modelmanager.advancedAddScanModel
);
const [advancedAddMode, setAdvancedAddMode] =
useState<ManualAddMode>('diffusers');
const [isCheckpoint, setIsCheckpoint] = useState<boolean>(true);
useEffect(() => {
advancedAddScanModel &&
['.ckpt', '.safetensors', '.pth', '.pt'].some((ext) =>
advancedAddScanModel.endsWith(ext)
)
? setAdvancedAddMode('checkpoint')
: setAdvancedAddMode('diffusers');
}, [advancedAddScanModel, setAdvancedAddMode, isCheckpoint]);
const dispatch = useAppDispatch();
return (
advancedAddScanModel && (
<Box
as={motion.div}
initial={{ x: -100, opacity: 0 }}
animate={{ x: 0, opacity: 1, transition: { duration: 0.2 } }}
sx={{
display: 'flex',
flexDirection: 'column',
minWidth: '40%',
maxHeight: window.innerHeight - 300,
overflow: 'scroll',
p: 4,
gap: 4,
borderRadius: 4,
bg: 'base.200',
_dark: {
bg: 'base.800',
},
}}
>
<Flex justifyContent="space-between" alignItems="center">
<Text size="xl" fontWeight={600}>
{isCheckpoint || advancedAddMode === 'checkpoint'
? 'Add Checkpoint Model'
: 'Add Diffusers Model'}
</Text>
<IAIIconButton
icon={<FaTimes />}
aria-label="Close Advanced"
onClick={() => dispatch(setAdvancedAddScanModel(null))}
size="sm"
/>
</Flex>
<IAIMantineSelect
label="Model Type"
value={advancedAddMode}
data={advancedAddModeData}
onChange={(v) => {
if (!v) return;
setAdvancedAddMode(v as ManualAddMode);
if (v === 'checkpoint') {
setIsCheckpoint(true);
} else {
setIsCheckpoint(false);
}
}}
/>
{isCheckpoint ? (
<AdvancedAddCheckpoint
key={advancedAddScanModel}
model_path={advancedAddScanModel}
/>
) : (
<AdvancedAddDiffusers
key={advancedAddScanModel}
model_path={advancedAddScanModel}
/>
)}
</Box>
)
);
}

View File

@ -0,0 +1,25 @@
import { Flex } from '@chakra-ui/react';
import FoundModelsList from './FoundModelsList';
import ScanAdvancedAddModels from './ScanAdvancedAddModels';
import SearchFolderForm from './SearchFolderForm';
export default function ScanModels() {
return (
<Flex flexDirection="column" w="100%" gap={4}>
<SearchFolderForm />
<Flex gap={4}>
<Flex
sx={{
maxHeight: window.innerHeight - 300,
overflow: 'scroll',
gap: 4,
w: '100%',
}}
>
<FoundModelsList />
</Flex>
<ScanAdvancedAddModels />
</Flex>
</Flex>
);
}

View File

@ -0,0 +1,139 @@
import { Flex, Text } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIInput from 'common/components/IAIInput';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { FaSearch, FaSync, FaTrash } from 'react-icons/fa';
import { useGetModelsInFolderQuery } from 'services/api/endpoints/models';
import {
setAdvancedAddScanModel,
setSearchFolder,
} from '../../store/modelManagerSlice';
type SearchFolderForm = {
folder: string;
};
function SearchFolderForm() {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const searchFolder = useAppSelector(
(state: RootState) => state.modelmanager.searchFolder
);
const { refetch: refetchFoundModels } = useGetModelsInFolderQuery({
search_path: searchFolder ? searchFolder : '',
});
const searchFolderForm = useForm<SearchFolderForm>({
initialValues: {
folder: '',
},
});
const searchFolderFormSubmitHandler = useCallback(
(values: SearchFolderForm) => {
dispatch(setSearchFolder(values.folder));
},
[dispatch]
);
const scanAgainHandler = () => {
refetchFoundModels();
};
return (
<form
onSubmit={searchFolderForm.onSubmit((values) =>
searchFolderFormSubmitHandler(values)
)}
style={{ width: '100%' }}
>
<Flex
sx={{
w: '100%',
gap: 2,
borderRadius: 4,
alignItems: 'center',
}}
>
<Flex w="100%" alignItems="center" gap={4} minH={12}>
<Text
sx={{
fontSize: 'sm',
fontWeight: 600,
color: 'base.700',
minW: 'max-content',
_dark: { color: 'base.300' },
}}
>
Folder
</Text>
{!searchFolder ? (
<IAIInput
w="100%"
size="md"
{...searchFolderForm.getInputProps('folder')}
/>
) : (
<Flex
sx={{
w: '100%',
p: 2,
px: 4,
bg: 'base.300',
borderRadius: 4,
fontSize: 'sm',
fontWeight: 'bold',
_dark: { bg: 'base.700' },
}}
>
{searchFolder}
</Flex>
)}
</Flex>
<Flex gap={2}>
{!searchFolder ? (
<IAIIconButton
aria-label={t('modelManager.findModels')}
tooltip={t('modelManager.findModels')}
icon={<FaSearch />}
fontSize={18}
size="sm"
type="submit"
/>
) : (
<IAIIconButton
aria-label={t('modelManager.scanAgain')}
tooltip={t('modelManager.scanAgain')}
icon={<FaSync />}
onClick={scanAgainHandler}
fontSize={18}
size="sm"
/>
)}
<IAIIconButton
aria-label={t('modelManager.clearCheckpointFolder')}
tooltip={t('modelManager.clearCheckpointFolder')}
icon={<FaTrash />}
size="sm"
onClick={() => {
dispatch(setSearchFolder(null));
dispatch(setAdvancedAddScanModel(null));
}}
isDisabled={!searchFolder}
colorScheme="red"
/>
</Flex>
</Flex>
</form>
);
}
export default memo(SearchFolderForm);

View File

@ -0,0 +1,108 @@
import { Flex } from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useTranslation } from 'react-i18next';
import { SelectItem } from '@mantine/core';
import { useForm } from '@mantine/form';
import { makeToast } from 'app/components/Toaster';
import { RootState } from 'app/store/store';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { addToast } from 'features/system/store/systemSlice';
import { useImportMainModelsMutation } from 'services/api/endpoints/models';
const predictionSelectData: SelectItem[] = [
{ label: 'None', value: 'none' },
{ label: 'v_prediction', value: 'v_prediction' },
{ label: 'epsilon', value: 'epsilon' },
{ label: 'sample', value: 'sample' },
];
type ExtendedImportModelConfig = {
location: string;
prediction_type?: 'v_prediction' | 'epsilon' | 'sample' | 'none' | undefined;
};
export default function SimpleAddModels() {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
const addModelForm = useForm<ExtendedImportModelConfig>({
initialValues: {
location: '',
prediction_type: undefined,
},
});
const handleAddModelSubmit = (values: ExtendedImportModelConfig) => {
const importModelResponseBody = {
location: values.location,
prediction_type:
values.prediction_type === 'none' ? undefined : values.prediction_type,
};
importMainModel({ body: importModelResponseBody })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: 'Model Added',
status: 'success',
})
)
);
addModelForm.reset();
})
.catch((error) => {
if (error) {
console.log(error);
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
};
return (
<form
onSubmit={addModelForm.onSubmit((v) => handleAddModelSubmit(v))}
style={{ width: '100%' }}
>
<Flex flexDirection="column" width="100%" gap={4}>
<IAIMantineTextInput
label="Model Location"
placeholder="Provide a path to a local Diffusers model, local checkpoint / safetensors model or a HuggingFace Repo ID"
w="100%"
{...addModelForm.getInputProps('location')}
/>
<IAIMantineSelect
label="Prediction Type (for Stable Diffusion 2.x Models only)"
data={predictionSelectData}
defaultValue="none"
{...addModelForm.getInputProps('prediction_type')}
/>
<IAIButton
type="submit"
isLoading={isLoading}
isDisabled={isLoading || isProcessing}
>
{t('modelManager.addModel')}
</IAIButton>
</Flex>
</form>
);
}

View File

@ -0,0 +1,39 @@
import { ButtonGroup, Flex } from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import { useState } from 'react';
import { useTranslation } from 'react-i18next';
import AddModels from './AddModelsPanel/AddModels';
import ScanModels from './AddModelsPanel/ScanModels';
type AddModelTabs = 'add' | 'scan';
export default function ImportModelsPanel() {
const [addModelTab, setAddModelTab] = useState<AddModelTabs>('add');
const { t } = useTranslation();
return (
<Flex flexDirection="column" gap={4}>
<ButtonGroup isAttached>
<IAIButton
onClick={() => setAddModelTab('add')}
isChecked={addModelTab == 'add'}
size="sm"
width="100%"
>
{t('modelManager.addModel')}
</IAIButton>
<IAIButton
onClick={() => setAddModelTab('scan')}
isChecked={addModelTab == 'scan'}
size="sm"
width="100%"
>
{t('modelManager.scanForModels')}
</IAIButton>
</ButtonGroup>
{addModelTab == 'add' && <AddModels />}
{addModelTab == 'scan' && <ScanModels />}
</Flex>
);
}

View File

@ -1,11 +1,4 @@
import { import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react';
Flex,
Radio,
RadioGroup,
Text,
Tooltip,
useColorMode,
} from '@chakra-ui/react';
import { makeToast } from 'app/components/Toaster'; import { makeToast } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
@ -23,7 +16,6 @@ import {
useMergeMainModelsMutation, useMergeMainModelsMutation,
} from 'services/api/endpoints/models'; } from 'services/api/endpoints/models';
import { BaseModelType, MergeModelConfig } from 'services/api/types'; import { BaseModelType, MergeModelConfig } from 'services/api/types';
import { mode } from 'theme/util/mode';
const baseModelTypeSelectData = [ const baseModelTypeSelectData = [
{ label: 'Stable Diffusion 1', value: 'sd-1' }, { label: 'Stable Diffusion 1', value: 'sd-1' },
@ -38,13 +30,9 @@ type MergeInterpolationMethods =
export default function MergeModelsPanel() { export default function MergeModelsPanel() {
const { t } = useTranslation(); const { t } = useTranslation();
const { colorMode } = useColorMode();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data } = useGetMainModelsQuery({ const { data } = useGetMainModelsQuery();
model_type: 'main',
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
});
const [mergeModels, { isLoading }] = useMergeMainModelsMutation(); const [mergeModels, { isLoading }] = useMergeMainModelsMutation();
@ -70,10 +58,10 @@ export default function MergeModelsPanel() {
}, [sd1DiffusersModels, sd2DiffusersModels]); }, [sd1DiffusersModels, sd2DiffusersModels]);
const [modelOne, setModelOne] = useState<string | null>( const [modelOne, setModelOne] = useState<string | null>(
Object.keys(modelsMap[baseModel])[0] Object.keys(modelsMap[baseModel as keyof typeof modelsMap])[0]
); );
const [modelTwo, setModelTwo] = useState<string | null>( const [modelTwo, setModelTwo] = useState<string | null>(
Object.keys(modelsMap[baseModel])[1] Object.keys(modelsMap[baseModel as keyof typeof modelsMap])[1]
); );
const [modelThree, setModelThree] = useState<string | null>(null); const [modelThree, setModelThree] = useState<string | null>(null);
@ -101,9 +89,9 @@ export default function MergeModelsPanel() {
modelsMap[baseModel as keyof typeof modelsMap] modelsMap[baseModel as keyof typeof modelsMap]
).filter((model) => model !== modelOne && model !== modelThree); ).filter((model) => model !== modelOne && model !== modelThree);
const modelThreeList = Object.keys(modelsMap[baseModel]).filter( const modelThreeList = Object.keys(
(model) => model !== modelOne && model !== modelTwo modelsMap[baseModel as keyof typeof modelsMap]
); ).filter((model) => model !== modelOne && model !== modelTwo);
const handleBaseModelChange = (v: string) => { const handleBaseModelChange = (v: string) => {
setBaseModel(v as BaseModelType); setBaseModel(v as BaseModelType);
@ -128,9 +116,9 @@ export default function MergeModelsPanel() {
mergedModelName !== '' ? mergedModelName : models_names.join('-'), mergedModelName !== '' ? mergedModelName : models_names.join('-'),
alpha: modelMergeAlpha, alpha: modelMergeAlpha,
interp: modelMergeInterp, interp: modelMergeInterp,
// model_merge_save_path:
// modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
force: modelMergeForce, force: modelMergeForce,
merge_dest_directory:
modelMergeSaveLocType === 'root' ? undefined : modelMergeCustomSaveLoc,
}; };
mergeModels({ mergeModels({
@ -230,7 +218,10 @@ export default function MergeModelsPanel() {
padding: 4, padding: 4,
borderRadius: 'base', borderRadius: 'base',
gap: 4, gap: 4,
bg: mode('base.100', 'base.800')(colorMode), bg: 'base.200',
_dark: {
bg: 'base.800',
},
}} }}
> >
<IAISlider <IAISlider
@ -255,7 +246,10 @@ export default function MergeModelsPanel() {
padding: 4, padding: 4,
borderRadius: 'base', borderRadius: 'base',
gap: 4, gap: 4,
bg: mode('base.100', 'base.800')(colorMode), bg: 'base.200',
_dark: {
bg: 'base.800',
},
}} }}
> >
<Text fontWeight={500} fontSize="sm" variant="subtext"> <Text fontWeight={500} fontSize="sm" variant="subtext">
@ -291,13 +285,16 @@ export default function MergeModelsPanel() {
</RadioGroup> </RadioGroup>
</Flex> </Flex>
{/* <Flex <Flex
sx={{ sx={{
flexDirection: 'column', flexDirection: 'column',
padding: 4, padding: 4,
borderRadius: 'base', borderRadius: 'base',
gap: 4, gap: 4,
bg: 'base.900', bg: 'base.200',
_dark: {
bg: 'base.900',
},
}} }}
> >
<Flex columnGap={4}> <Flex columnGap={4}>
@ -327,7 +324,7 @@ export default function MergeModelsPanel() {
onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)} onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
/> />
)} )}
</Flex> */} </Flex>
<IAISimpleCheckbox <IAISimpleCheckbox
label={t('modelManager.ignoreMismatch')} label={t('modelManager.ignoreMismatch')}

View File

@ -1,33 +1,26 @@
import { Divider, Flex, Text } from '@chakra-ui/react'; import { Badge, Divider, Flex, Text } from '@chakra-ui/react';
import { useForm } from '@mantine/form'; import { useForm } from '@mantine/form';
import { makeToast } from 'app/components/Toaster'; import { makeToast } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput'; import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { selectIsBusy } from 'features/system/store/systemSelectors'; import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { useCallback } from 'react'; import { useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { import {
CheckpointModelConfigEntity, CheckpointModelConfigEntity,
useGetCheckpointConfigsQuery,
useUpdateMainModelsMutation, useUpdateMainModelsMutation,
} from 'services/api/endpoints/models'; } from 'services/api/endpoints/models';
import { CheckpointModelConfig } from 'services/api/types'; import { CheckpointModelConfig } from 'services/api/types';
import BaseModelSelect from '../shared/BaseModelSelect';
import CheckpointConfigsSelect from '../shared/CheckpointConfigsSelect';
import ModelVariantSelect from '../shared/ModelVariantSelect';
import ModelConvert from './ModelConvert'; import ModelConvert from './ModelConvert';
const baseModelSelectData = [
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
];
const variantSelectData = [
{ value: 'normal', label: 'Normal' },
{ value: 'inpaint', label: 'Inpaint' },
{ value: 'depth', label: 'Depth' },
];
type CheckpointModelEditProps = { type CheckpointModelEditProps = {
model: CheckpointModelConfigEntity; model: CheckpointModelConfigEntity;
}; };
@ -38,6 +31,15 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
const { model } = props; const { model } = props;
const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation(); const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation();
const { data: availableCheckpointConfigs } = useGetCheckpointConfigsQuery();
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
useEffect(() => {
if (!availableCheckpointConfigs?.includes(model.config)) {
setUseCustomConfig(true);
}
}, [availableCheckpointConfigs, model.config]);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
@ -80,7 +82,7 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
) )
); );
}) })
.catch((error) => { .catch((_) => {
checkpointEditForm.reset(); checkpointEditForm.reset();
dispatch( dispatch(
addToast( addToast(
@ -113,7 +115,20 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
{MODEL_TYPE_MAP[model.base_model]} Model {MODEL_TYPE_MAP[model.base_model]} Model
</Text> </Text>
</Flex> </Flex>
<ModelConvert model={model} /> {!['sdxl', 'sdxl-refiner'].includes(model.base_model) ? (
<ModelConvert model={model} />
) : (
<Badge
sx={{
p: 2,
borderRadius: 4,
bg: 'error.200',
_dark: { bg: 'error.400' },
}}
>
Conversion Not Supported
</Badge>
)}
</Flex> </Flex>
<Divider /> <Divider />
@ -128,21 +143,24 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
)} )}
> >
<Flex flexDirection="column" overflowY="scroll" gap={4}> <Flex flexDirection="column" overflowY="scroll" gap={4}>
<IAIMantineTextInput
label={t('modelManager.name')}
{...checkpointEditForm.getInputProps('model_name')}
/>
<IAIMantineTextInput <IAIMantineTextInput
label={t('modelManager.description')} label={t('modelManager.description')}
{...checkpointEditForm.getInputProps('description')} {...checkpointEditForm.getInputProps('description')}
/> />
<IAIMantineSelect <BaseModelSelect
label={t('modelManager.baseModel')} required
data={baseModelSelectData}
{...checkpointEditForm.getInputProps('base_model')} {...checkpointEditForm.getInputProps('base_model')}
/> />
<IAIMantineSelect <ModelVariantSelect
label={t('modelManager.variant')} required
data={variantSelectData}
{...checkpointEditForm.getInputProps('variant')} {...checkpointEditForm.getInputProps('variant')}
/> />
<IAIMantineTextInput <IAIMantineTextInput
required
label={t('modelManager.modelLocation')} label={t('modelManager.modelLocation')}
{...checkpointEditForm.getInputProps('path')} {...checkpointEditForm.getInputProps('path')}
/> />
@ -150,10 +168,27 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
label={t('modelManager.vaeLocation')} label={t('modelManager.vaeLocation')}
{...checkpointEditForm.getInputProps('vae')} {...checkpointEditForm.getInputProps('vae')}
/> />
<IAIMantineTextInput
label={t('modelManager.config')} <Flex flexDirection="column" gap={2}>
{...checkpointEditForm.getInputProps('config')} {!useCustomConfig ? (
/> <CheckpointConfigsSelect
required
{...checkpointEditForm.getInputProps('config')}
/>
) : (
<IAIMantineTextInput
required
label={t('modelManager.config')}
{...checkpointEditForm.getInputProps('config')}
/>
)}
<IAISimpleCheckbox
isChecked={useCustomConfig}
onChange={() => setUseCustomConfig(!useCustomConfig)}
label="Use Custom Config"
/>
</Flex>
<IAIButton <IAIButton
type="submit" type="submit"
isDisabled={isBusy || isLoading} isDisabled={isBusy || isLoading}

View File

@ -4,7 +4,6 @@ import { makeToast } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput'; import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { selectIsBusy } from 'features/system/store/systemSelectors'; import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
@ -15,22 +14,13 @@ import {
useUpdateMainModelsMutation, useUpdateMainModelsMutation,
} from 'services/api/endpoints/models'; } from 'services/api/endpoints/models';
import { DiffusersModelConfig } from 'services/api/types'; import { DiffusersModelConfig } from 'services/api/types';
import BaseModelSelect from '../shared/BaseModelSelect';
import ModelVariantSelect from '../shared/ModelVariantSelect';
type DiffusersModelEditProps = { type DiffusersModelEditProps = {
model: DiffusersModelConfigEntity; model: DiffusersModelConfigEntity;
}; };
const baseModelSelectData = [
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
];
const variantSelectData = [
{ value: 'normal', label: 'Normal' },
{ value: 'inpaint', label: 'Inpaint' },
{ value: 'depth', label: 'Depth' },
];
export default function DiffusersModelEdit(props: DiffusersModelEditProps) { export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
const isBusy = useAppSelector(selectIsBusy); const isBusy = useAppSelector(selectIsBusy);
@ -65,6 +55,7 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
model_name: model.model_name, model_name: model.model_name,
body: values, body: values,
}; };
updateMainModel(responseBody) updateMainModel(responseBody)
.unwrap() .unwrap()
.then((payload) => { .then((payload) => {
@ -78,7 +69,7 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
) )
); );
}) })
.catch((error) => { .catch((_) => {
diffusersEditForm.reset(); diffusersEditForm.reset();
dispatch( dispatch(
addToast( addToast(
@ -118,21 +109,24 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
)} )}
> >
<Flex flexDirection="column" overflowY="scroll" gap={4}> <Flex flexDirection="column" overflowY="scroll" gap={4}>
<IAIMantineTextInput
label={t('modelManager.name')}
{...diffusersEditForm.getInputProps('model_name')}
/>
<IAIMantineTextInput <IAIMantineTextInput
label={t('modelManager.description')} label={t('modelManager.description')}
{...diffusersEditForm.getInputProps('description')} {...diffusersEditForm.getInputProps('description')}
/> />
<IAIMantineSelect <BaseModelSelect
label={t('modelManager.baseModel')} required
data={baseModelSelectData}
{...diffusersEditForm.getInputProps('base_model')} {...diffusersEditForm.getInputProps('base_model')}
/> />
<IAIMantineSelect <ModelVariantSelect
label={t('modelManager.variant')} required
data={variantSelectData}
{...diffusersEditForm.getInputProps('variant')} {...diffusersEditForm.getInputProps('variant')}
/> />
<IAIMantineTextInput <IAIMantineTextInput
required
label={t('modelManager.modelLocation')} label={t('modelManager.modelLocation')}
{...diffusersEditForm.getInputProps('path')} {...diffusersEditForm.getInputProps('path')}
/> />

Some files were not shown because too many files have changed in this diff Show More