mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into save_vram
This commit is contained in:
commit
889b77d3d6
19
README.md
19
README.md
@ -132,8 +132,10 @@ and go to http://localhost:9090.
|
|||||||
|
|
||||||
### Command-Line Installation (for developers and users familiar with Terminals)
|
### Command-Line Installation (for developers and users familiar with Terminals)
|
||||||
|
|
||||||
You must have Python 3.9 or 3.10 installed on your machine. Earlier or later versions are
|
You must have Python 3.9 or 3.10 installed on your machine. Earlier or
|
||||||
not supported.
|
later versions are not supported.
|
||||||
|
Node.js also needs to be installed along with yarn (can be installed with
|
||||||
|
the command `npm install -g yarn` if needed)
|
||||||
|
|
||||||
1. Open a command-line window on your machine. The PowerShell is recommended for Windows.
|
1. Open a command-line window on your machine. The PowerShell is recommended for Windows.
|
||||||
2. Create a directory to install InvokeAI into. You'll need at least 15 GB of free space:
|
2. Create a directory to install InvokeAI into. You'll need at least 15 GB of free space:
|
||||||
@ -197,11 +199,18 @@ not supported.
|
|||||||
7. Launch the web server (do it every time you run InvokeAI):
|
7. Launch the web server (do it every time you run InvokeAI):
|
||||||
|
|
||||||
```terminal
|
```terminal
|
||||||
invokeai --web
|
invokeai-web
|
||||||
```
|
```
|
||||||
|
|
||||||
8. Point your browser to http://localhost:9090 to bring up the web interface.
|
8. Build Node.js assets
|
||||||
9. Type `banana sushi` in the box on the top left and click `Invoke`.
|
|
||||||
|
```terminal
|
||||||
|
cd invokeai/frontend/web/
|
||||||
|
yarn vite build
|
||||||
|
```
|
||||||
|
|
||||||
|
9. Point your browser to http://localhost:9090 to bring up the web interface.
|
||||||
|
10. Type `banana sushi` in the box on the top left and click `Invoke`.
|
||||||
|
|
||||||
Be sure to activate the virtual environment each time before re-launching InvokeAI,
|
Be sure to activate the virtual environment each time before re-launching InvokeAI,
|
||||||
using `source .venv/bin/activate` or `.venv\Scripts\activate`.
|
using `source .venv/bin/activate` or `.venv\Scripts\activate`.
|
||||||
|
@ -81,3 +81,193 @@ pytest --cov; open ./coverage/html/index.html
|
|||||||
<!--#TODO: get input from blessedcoolant here, for the moment inserted the frontend README via snippets extension.-->
|
<!--#TODO: get input from blessedcoolant here, for the moment inserted the frontend README via snippets extension.-->
|
||||||
|
|
||||||
--8<-- "invokeai/frontend/web/README.md"
|
--8<-- "invokeai/frontend/web/README.md"
|
||||||
|
|
||||||
|
## Developing InvokeAI in VSCode
|
||||||
|
|
||||||
|
VSCode offers some nice tools:
|
||||||
|
|
||||||
|
- python debugger
|
||||||
|
- automatic `venv` activation
|
||||||
|
- remote dev (e.g. run InvokeAI on a beefy linux desktop while you type in
|
||||||
|
comfort on your macbook)
|
||||||
|
|
||||||
|
### Setup
|
||||||
|
|
||||||
|
You'll need the
|
||||||
|
[Python](https://marketplace.visualstudio.com/items?itemName=ms-python.python)
|
||||||
|
and
|
||||||
|
[Pylance](https://marketplace.visualstudio.com/items?itemName=ms-python.vscode-pylance)
|
||||||
|
extensions installed first.
|
||||||
|
|
||||||
|
It's also really handy to install the `Jupyter` extensions:
|
||||||
|
|
||||||
|
- [Jupyter](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.jupyter)
|
||||||
|
- [Jupyter Cell Tags](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.vscode-jupyter-cell-tags)
|
||||||
|
- [Jupyter Notebook Renderers](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.jupyter-renderers)
|
||||||
|
- [Jupyter Slide Show](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.vscode-jupyter-slideshow)
|
||||||
|
|
||||||
|
#### InvokeAI workspace
|
||||||
|
|
||||||
|
Creating a VSCode workspace for working on InvokeAI is highly recommended. It
|
||||||
|
can hold InvokeAI-specific settings and configs.
|
||||||
|
|
||||||
|
To make a workspace:
|
||||||
|
|
||||||
|
- Open the InvokeAI repo dir in VSCode
|
||||||
|
- `File` > `Save Workspace As` > save it _outside_ the repo
|
||||||
|
|
||||||
|
#### Default python interpreter (i.e. automatic virtual environment activation)
|
||||||
|
|
||||||
|
- Use command palette to run command
|
||||||
|
`Preferences: Open Workspace Settings (JSON)`
|
||||||
|
- Add `python.defaultInterpreterPath` to `settings`, pointing to your `venv`'s
|
||||||
|
python
|
||||||
|
|
||||||
|
Should look something like this:
|
||||||
|
|
||||||
|
```jsonc
|
||||||
|
{
|
||||||
|
// I like to have all InvokeAI-related folders in my workspace
|
||||||
|
"folders": [
|
||||||
|
{
|
||||||
|
// repo root
|
||||||
|
"path": "InvokeAI"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// InvokeAI root dir, where `invokeai.yaml` lives
|
||||||
|
"path": "/path/to/invokeai_root"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"settings": {
|
||||||
|
// Where your InvokeAI `venv`'s python executable lives
|
||||||
|
"python.defaultInterpreterPath": "/path/to/invokeai_root/.venv/bin/python"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Now when you open the VSCode integrated terminal, or do anything that needs to
|
||||||
|
run python, it will automatically be in your InvokeAI virtual environment.
|
||||||
|
|
||||||
|
Bonus: When you create a Jupyter notebook, when you run it, you'll be prompted
|
||||||
|
for the python interpreter to run in. This will default to your `venv` python,
|
||||||
|
and so you'll have access to the same python environment as the InvokeAI app.
|
||||||
|
|
||||||
|
This is _super_ handy.
|
||||||
|
|
||||||
|
#### Debugging configs with `launch.json`
|
||||||
|
|
||||||
|
Debugging configs are managed in a `launch.json` file. Like most VSCode configs,
|
||||||
|
these can be scoped to a workspace or folder.
|
||||||
|
|
||||||
|
Follow the [official guide](https://code.visualstudio.com/docs/python/debugging)
|
||||||
|
to set up your `launch.json` and try it out.
|
||||||
|
|
||||||
|
Now we can create the InvokeAI debugging configs:
|
||||||
|
|
||||||
|
```jsonc
|
||||||
|
{
|
||||||
|
// Use IntelliSense to learn about possible attributes.
|
||||||
|
// Hover to view descriptions of existing attributes.
|
||||||
|
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||||
|
"version": "0.2.0",
|
||||||
|
"configurations": [
|
||||||
|
{
|
||||||
|
// Run the InvokeAI backend & serve the pre-built UI
|
||||||
|
"name": "InvokeAI Web",
|
||||||
|
"type": "python",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "scripts/invokeai-web.py",
|
||||||
|
"args": [
|
||||||
|
// Your InvokeAI root dir (where `invokeai.yaml` lives)
|
||||||
|
"--root",
|
||||||
|
"/path/to/invokeai_root",
|
||||||
|
// Access the app from anywhere on your local network
|
||||||
|
"--host",
|
||||||
|
"0.0.0.0"
|
||||||
|
],
|
||||||
|
"justMyCode": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Run the nodes-based CLI
|
||||||
|
"name": "InvokeAI CLI",
|
||||||
|
"type": "python",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "scripts/invokeai-cli.py",
|
||||||
|
"justMyCode": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Run tests
|
||||||
|
"name": "InvokeAI Test",
|
||||||
|
"type": "python",
|
||||||
|
"request": "launch",
|
||||||
|
"module": "pytest",
|
||||||
|
"args": ["--capture=no"],
|
||||||
|
"justMyCode": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Run a single test
|
||||||
|
"name": "InvokeAI Single Test",
|
||||||
|
"type": "python",
|
||||||
|
"request": "launch",
|
||||||
|
"module": "pytest",
|
||||||
|
"args": [
|
||||||
|
// Change this to point to the specific test you are working on
|
||||||
|
"tests/nodes/test_invoker.py"
|
||||||
|
],
|
||||||
|
"justMyCode": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// This is the default, useful to just run a single file
|
||||||
|
"name": "Python: File",
|
||||||
|
"type": "python",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "${file}",
|
||||||
|
"justMyCode": true
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
You'll see these configs in the debugging configs drop down. Running them will
|
||||||
|
start InvokeAI with attached debugger, in the correct environment, and work just
|
||||||
|
like the normal app.
|
||||||
|
|
||||||
|
Enjoy debugging InvokeAI with ease (not that we have any bugs of course).
|
||||||
|
|
||||||
|
#### Remote dev
|
||||||
|
|
||||||
|
This is very easy to set up and provides the same very smooth experience as
|
||||||
|
local development. Environments and debugging, as set up above, just work,
|
||||||
|
though you'd need to recreate the workspace and debugging configs on the remote.
|
||||||
|
|
||||||
|
Consult the
|
||||||
|
[official guide](https://code.visualstudio.com/docs/remote/remote-overview) to
|
||||||
|
get it set up.
|
||||||
|
|
||||||
|
Suggest using VSCode's included settings sync so that your remote dev host has
|
||||||
|
all the same app settings and extensions automagically.
|
||||||
|
|
||||||
|
##### One remote dev gotcha
|
||||||
|
|
||||||
|
I've found the automatic port forwarding to be very flakey. You can disable it
|
||||||
|
in `Preferences: Open Remote Settings (ssh: hostname)`. Search for
|
||||||
|
`remote.autoForwardPorts` and untick the box.
|
||||||
|
|
||||||
|
To forward ports very reliably, use SSH on the remote dev client (e.g. your
|
||||||
|
macbook). Here's how to forward both backend API port (`9090`) and the frontend
|
||||||
|
live dev server port (`5173`):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ssh \
|
||||||
|
-L 9090:localhost:9090 \
|
||||||
|
-L 5173:localhost:5173 \
|
||||||
|
user@remote-dev-host
|
||||||
|
```
|
||||||
|
|
||||||
|
The forwarding stops when you close the terminal window, so suggest to do this
|
||||||
|
_outside_ the VSCode integrated terminal in case you need to restart VSCode for
|
||||||
|
an extension update or something
|
||||||
|
|
||||||
|
Now, on your remote dev client, you can open `localhost:9090` and access the UI,
|
||||||
|
now served from the remote dev host, just the same as if it was running on the
|
||||||
|
client.
|
||||||
|
@ -1,4 +1,8 @@
|
|||||||
# Nodes Editor (Experimental Beta)
|
# Nodes Editor (Experimental)
|
||||||
|
|
||||||
|
🚨
|
||||||
|
*The node editor is experimental. We've made it accessible because we use it to develop the application, but we have not addressed the many known rough edges. It's very easy to shoot yourself in the foot, and we cannot offer support for it until it sees full release (ETA v3.1). Everything is subject to change without warning.*
|
||||||
|
🚨
|
||||||
|
|
||||||
The nodes editor is a blank canvas allowing for the use of individual functions and image transformations to control the image generation workflow. The node processing flow is usually done from left (inputs) to right (outputs), though linearity can become abstracted the more complex the node graph becomes. Nodes inputs and outputs are connected by dragging connectors from node to node.
|
The nodes editor is a blank canvas allowing for the use of individual functions and image transformations to control the image generation workflow. The node processing flow is usually done from left (inputs) to right (outputs), though linearity can become abstracted the more complex the node graph becomes. Nodes inputs and outputs are connected by dragging connectors from node to node.
|
||||||
|
|
||||||
|
@ -24,7 +24,8 @@ read -e -p "Tag this repo with '${VERSION}' and '${LATEST_TAG}'? [n]: " input
|
|||||||
RESPONSE=${input:='n'}
|
RESPONSE=${input:='n'}
|
||||||
if [ "$RESPONSE" == 'y' ]; then
|
if [ "$RESPONSE" == 'y' ]; then
|
||||||
|
|
||||||
if ! git tag $VERSION ; then
|
git push origin :refs/tags/$VERSION
|
||||||
|
if ! git tag -fa $VERSION ; then
|
||||||
echo "Existing/invalid tag"
|
echo "Existing/invalid tag"
|
||||||
exit -1
|
exit -1
|
||||||
fi
|
fi
|
||||||
|
@ -38,7 +38,7 @@ echo https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist
|
|||||||
echo.
|
echo.
|
||||||
echo See %INSTRUCTIONS% for more details.
|
echo See %INSTRUCTIONS% for more details.
|
||||||
echo.
|
echo.
|
||||||
echo "For the best user experience we suggest enlarging or maximizing this window now."
|
echo FOR THE BEST USER EXPERIENCE WE SUGGEST MAXIMIZING THIS WINDOW NOW.
|
||||||
pause
|
pause
|
||||||
|
|
||||||
@rem ---------------------------- check Python version ---------------
|
@rem ---------------------------- check Python version ---------------
|
||||||
|
@ -19,7 +19,7 @@ echo 8. Open the developer console
|
|||||||
echo 9. Update InvokeAI
|
echo 9. Update InvokeAI
|
||||||
echo 10. Command-line help
|
echo 10. Command-line help
|
||||||
echo Q - Quit
|
echo Q - Quit
|
||||||
set /P choice="Please enter 1-10, Q: [2] "
|
set /P choice="Please enter 1-10, Q: [1] "
|
||||||
if not defined choice set choice=1
|
if not defined choice set choice=1
|
||||||
IF /I "%choice%" == "1" (
|
IF /I "%choice%" == "1" (
|
||||||
echo Starting the InvokeAI browser-based UI..
|
echo Starting the InvokeAI browser-based UI..
|
||||||
|
@ -11,6 +11,7 @@ from invokeai.app.services.board_images import (
|
|||||||
)
|
)
|
||||||
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
||||||
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||||
from invokeai.app.services.resource_name import SimpleNameService
|
from invokeai.app.services.resource_name import SimpleNameService
|
||||||
@ -20,7 +21,6 @@ from invokeai.version.invokeai_version import __version__
|
|||||||
|
|
||||||
from ..services.default_graphs import create_system_graphs
|
from ..services.default_graphs import create_system_graphs
|
||||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
from ..services.restoration_services import RestorationServices
|
|
||||||
from ..services.graph import GraphExecutionState, LibraryGraph
|
from ..services.graph import GraphExecutionState, LibraryGraph
|
||||||
from ..services.image_file_storage import DiskImageFileStorage
|
from ..services.image_file_storage import DiskImageFileStorage
|
||||||
from ..services.invocation_queue import MemoryInvocationQueue
|
from ..services.invocation_queue import MemoryInvocationQueue
|
||||||
@ -57,8 +57,8 @@ class ApiDependencies:
|
|||||||
invoker: Invoker = None
|
invoker: Invoker = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def initialize(config, event_handler_id: int, logger: Logger = logger):
|
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
|
||||||
logger.debug(f'InvokeAI version {__version__}')
|
logger.debug(f"InvokeAI version {__version__}")
|
||||||
logger.debug(f"Internet connectivity is {config.internet_available}")
|
logger.debug(f"Internet connectivity is {config.internet_available}")
|
||||||
|
|
||||||
events = FastAPIEventService(event_handler_id)
|
events = FastAPIEventService(event_handler_id)
|
||||||
@ -117,7 +117,7 @@ class ApiDependencies:
|
|||||||
)
|
)
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
model_manager=ModelManagerService(config,logger),
|
model_manager=ModelManagerService(config, logger),
|
||||||
events=events,
|
events=events,
|
||||||
latents=latents,
|
latents=latents,
|
||||||
images=images,
|
images=images,
|
||||||
@ -129,7 +129,6 @@ class ApiDependencies:
|
|||||||
),
|
),
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
restoration=RestorationServices(config, logger),
|
|
||||||
configuration=config,
|
configuration=config,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
)
|
)
|
||||||
|
@ -13,8 +13,11 @@ from invokeai.backend import BaseModelType, ModelType
|
|||||||
from invokeai.backend.model_management.models import (
|
from invokeai.backend.model_management.models import (
|
||||||
OPENAPI_MODEL_CONFIGS,
|
OPENAPI_MODEL_CONFIGS,
|
||||||
SchedulerPredictionType,
|
SchedulerPredictionType,
|
||||||
|
ModelNotFoundException,
|
||||||
|
InvalidModelException,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_management import MergeInterpolationMethod
|
from invokeai.backend.model_management import MergeInterpolationMethod
|
||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||||
@ -51,8 +54,9 @@ async def list_models(
|
|||||||
"/{base_model}/{model_type}/{model_name}",
|
"/{base_model}/{model_type}/{model_name}",
|
||||||
operation_id="update_model",
|
operation_id="update_model",
|
||||||
responses={200: {"description" : "The model was updated successfully"},
|
responses={200: {"description" : "The model was updated successfully"},
|
||||||
|
400: {"description" : "Bad request"},
|
||||||
404: {"description" : "The model could not be found"},
|
404: {"description" : "The model could not be found"},
|
||||||
400: {"description" : "Bad request"}
|
409: {"description" : "There is already a model corresponding to the new name"},
|
||||||
},
|
},
|
||||||
status_code = 200,
|
status_code = 200,
|
||||||
response_model = UpdateModelResponse,
|
response_model = UpdateModelResponse,
|
||||||
@ -63,23 +67,58 @@ async def update_model(
|
|||||||
model_name: str = Path(description="model name"),
|
model_name: str = Path(description="model name"),
|
||||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||||
) -> UpdateModelResponse:
|
) -> UpdateModelResponse:
|
||||||
""" Add Model """
|
""" Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. """
|
||||||
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
previous_info = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
# rename operation requested
|
||||||
|
if info.model_name != model_name or info.base_model != base_model:
|
||||||
|
ApiDependencies.invoker.services.model_manager.rename_model(
|
||||||
|
base_model = base_model,
|
||||||
|
model_type = model_type,
|
||||||
|
model_name = model_name,
|
||||||
|
new_name = info.model_name,
|
||||||
|
new_base = info.base_model,
|
||||||
|
)
|
||||||
|
logger.info(f'Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}')
|
||||||
|
# update information to support an update of attributes
|
||||||
|
model_name = info.model_name
|
||||||
|
base_model = info.base_model
|
||||||
|
new_info = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
)
|
||||||
|
if new_info.get('path') != previous_info.get('path'): # model manager moved model path during rename - don't overwrite it
|
||||||
|
info.path = new_info.get('path')
|
||||||
|
|
||||||
ApiDependencies.invoker.services.model_manager.update_model(
|
ApiDependencies.invoker.services.model_manager.update_model(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
model_attributes=info.dict()
|
model_attributes=info.dict()
|
||||||
)
|
)
|
||||||
|
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
)
|
)
|
||||||
model_response = parse_obj_as(UpdateModelResponse, model_raw)
|
model_response = parse_obj_as(UpdateModelResponse, model_raw)
|
||||||
except KeyError as e:
|
except ModelNotFoundException as e:
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
raise HTTPException(status_code=409, detail=str(e))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
return model_response
|
return model_response
|
||||||
@ -90,6 +129,7 @@ async def update_model(
|
|||||||
responses= {
|
responses= {
|
||||||
201: {"description" : "The model imported successfully"},
|
201: {"description" : "The model imported successfully"},
|
||||||
404: {"description" : "The model could not be found"},
|
404: {"description" : "The model could not be found"},
|
||||||
|
415: {"description" : "Unrecognized file/folder format"},
|
||||||
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
|
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
|
||||||
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
||||||
},
|
},
|
||||||
@ -116,7 +156,7 @@ async def import_model(
|
|||||||
|
|
||||||
if not info:
|
if not info:
|
||||||
logger.error("Import failed")
|
logger.error("Import failed")
|
||||||
raise HTTPException(status_code=424)
|
raise HTTPException(status_code=415)
|
||||||
|
|
||||||
logger.info(f'Successfully imported {location}, got {info}')
|
logger.info(f'Successfully imported {location}, got {info}')
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
@ -126,9 +166,12 @@ async def import_model(
|
|||||||
)
|
)
|
||||||
return parse_obj_as(ImportModelResponse, model_raw)
|
return parse_obj_as(ImportModelResponse, model_raw)
|
||||||
|
|
||||||
except KeyError as e:
|
except ModelNotFoundException as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
except InvalidModelException as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
raise HTTPException(status_code=415)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=409, detail=str(e))
|
raise HTTPException(status_code=409, detail=str(e))
|
||||||
@ -166,57 +209,13 @@ async def add_model(
|
|||||||
model_type=info.model_type
|
model_type=info.model_type
|
||||||
)
|
)
|
||||||
return parse_obj_as(ImportModelResponse, model_raw)
|
return parse_obj_as(ImportModelResponse, model_raw)
|
||||||
except KeyError as e:
|
except ModelNotFoundException as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=409, detail=str(e))
|
raise HTTPException(status_code=409, detail=str(e))
|
||||||
|
|
||||||
@models_router.post(
|
|
||||||
"/rename/{base_model}/{model_type}/{model_name}",
|
|
||||||
operation_id="rename_model",
|
|
||||||
responses= {
|
|
||||||
201: {"description" : "The model was renamed successfully"},
|
|
||||||
404: {"description" : "The model could not be found"},
|
|
||||||
409: {"description" : "There is already a model corresponding to the new name"},
|
|
||||||
},
|
|
||||||
status_code=201,
|
|
||||||
response_model=ImportModelResponse
|
|
||||||
)
|
|
||||||
async def rename_model(
|
|
||||||
base_model: BaseModelType = Path(description="Base model"),
|
|
||||||
model_type: ModelType = Path(description="The type of model"),
|
|
||||||
model_name: str = Path(description="current model name"),
|
|
||||||
new_name: Optional[str] = Query(description="new model name", default=None),
|
|
||||||
new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
|
|
||||||
) -> ImportModelResponse:
|
|
||||||
""" Rename a model"""
|
|
||||||
|
|
||||||
logger = ApiDependencies.invoker.services.logger
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = ApiDependencies.invoker.services.model_manager.rename_model(
|
|
||||||
base_model = base_model,
|
|
||||||
model_type = model_type,
|
|
||||||
model_name = model_name,
|
|
||||||
new_name = new_name,
|
|
||||||
new_base = new_base,
|
|
||||||
)
|
|
||||||
logger.debug(result)
|
|
||||||
logger.info(f'Successfully renamed {model_name}=>{new_name}')
|
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
|
||||||
model_name=new_name or model_name,
|
|
||||||
base_model=new_base or base_model,
|
|
||||||
model_type=model_type
|
|
||||||
)
|
|
||||||
return parse_obj_as(ImportModelResponse, model_raw)
|
|
||||||
except KeyError as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
raise HTTPException(status_code=409, detail=str(e))
|
|
||||||
|
|
||||||
@models_router.delete(
|
@models_router.delete(
|
||||||
"/{base_model}/{model_type}/{model_name}",
|
"/{base_model}/{model_type}/{model_name}",
|
||||||
@ -243,9 +242,9 @@ async def delete_model(
|
|||||||
)
|
)
|
||||||
logger.info(f"Deleted model: {model_name}")
|
logger.info(f"Deleted model: {model_name}")
|
||||||
return Response(status_code=204)
|
return Response(status_code=204)
|
||||||
except KeyError:
|
except ModelNotFoundException as e:
|
||||||
logger.error(f"Model not found: {model_name}")
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
|
||||||
@models_router.put(
|
@models_router.put(
|
||||||
"/convert/{base_model}/{model_type}/{model_name}",
|
"/convert/{base_model}/{model_type}/{model_name}",
|
||||||
@ -278,8 +277,8 @@ async def convert_model(
|
|||||||
base_model = base_model,
|
base_model = base_model,
|
||||||
model_type = model_type)
|
model_type = model_type)
|
||||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||||
except KeyError:
|
except ModelNotFoundException as e:
|
||||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
return response
|
return response
|
||||||
@ -369,8 +368,55 @@ async def merge_models(
|
|||||||
model_type = ModelType.Main,
|
model_type = ModelType.Main,
|
||||||
)
|
)
|
||||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||||
except KeyError:
|
except ModelNotFoundException:
|
||||||
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
|
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
# The rename operation is now supported by update_model and no longer needs to be
|
||||||
|
# a standalone route.
|
||||||
|
# @models_router.post(
|
||||||
|
# "/rename/{base_model}/{model_type}/{model_name}",
|
||||||
|
# operation_id="rename_model",
|
||||||
|
# responses= {
|
||||||
|
# 201: {"description" : "The model was renamed successfully"},
|
||||||
|
# 404: {"description" : "The model could not be found"},
|
||||||
|
# 409: {"description" : "There is already a model corresponding to the new name"},
|
||||||
|
# },
|
||||||
|
# status_code=201,
|
||||||
|
# response_model=ImportModelResponse
|
||||||
|
# )
|
||||||
|
# async def rename_model(
|
||||||
|
# base_model: BaseModelType = Path(description="Base model"),
|
||||||
|
# model_type: ModelType = Path(description="The type of model"),
|
||||||
|
# model_name: str = Path(description="current model name"),
|
||||||
|
# new_name: Optional[str] = Query(description="new model name", default=None),
|
||||||
|
# new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
|
||||||
|
# ) -> ImportModelResponse:
|
||||||
|
# """ Rename a model"""
|
||||||
|
|
||||||
|
# logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# result = ApiDependencies.invoker.services.model_manager.rename_model(
|
||||||
|
# base_model = base_model,
|
||||||
|
# model_type = model_type,
|
||||||
|
# model_name = model_name,
|
||||||
|
# new_name = new_name,
|
||||||
|
# new_base = new_base,
|
||||||
|
# )
|
||||||
|
# logger.debug(result)
|
||||||
|
# logger.info(f'Successfully renamed {model_name}=>{new_name}')
|
||||||
|
# model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
|
# model_name=new_name or model_name,
|
||||||
|
# base_model=new_base or base_model,
|
||||||
|
# model_type=model_type
|
||||||
|
# )
|
||||||
|
# return parse_obj_as(ImportModelResponse, model_raw)
|
||||||
|
# except ModelNotFoundException as e:
|
||||||
|
# logger.error(str(e))
|
||||||
|
# raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
# except ValueError as e:
|
||||||
|
# logger.error(str(e))
|
||||||
|
# raise HTTPException(status_code=409, detail=str(e))
|
||||||
|
@ -39,6 +39,7 @@ from .invocations.baseinvocation import BaseInvocation
|
|||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import invokeai.backend.util.hotfixes
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
import invokeai.backend.util.mps_fixes
|
import invokeai.backend.util.mps_fixes
|
||||||
|
|
||||||
|
@ -54,10 +54,10 @@ from .services.invocation_services import InvocationServices
|
|||||||
from .services.invoker import Invoker
|
from .services.invoker import Invoker
|
||||||
from .services.model_manager_service import ModelManagerService
|
from .services.model_manager_service import ModelManagerService
|
||||||
from .services.processor import DefaultInvocationProcessor
|
from .services.processor import DefaultInvocationProcessor
|
||||||
from .services.restoration_services import RestorationServices
|
|
||||||
from .services.sqlite import SqliteItemStorage
|
from .services.sqlite import SqliteItemStorage
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import invokeai.backend.util.hotfixes
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
import invokeai.backend.util.mps_fixes
|
import invokeai.backend.util.mps_fixes
|
||||||
|
|
||||||
@ -295,7 +295,6 @@ def invoke_cli():
|
|||||||
),
|
),
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
restoration=RestorationServices(config,logger=logger),
|
|
||||||
logger=logger,
|
logger=logger,
|
||||||
configuration=config,
|
configuration=config,
|
||||||
)
|
)
|
||||||
|
@ -86,10 +86,10 @@ class CompelInvocation(BaseInvocation):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**self.clip.tokenizer.dict(),
|
**self.clip.tokenizer.dict(), context=context,
|
||||||
)
|
)
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
text_encoder_info = context.services.model_manager.get_model(
|
||||||
**self.clip.text_encoder.dict(),
|
**self.clip.text_encoder.dict(), context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
@ -111,6 +111,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
model_name=name,
|
model_name=name,
|
||||||
base_model=self.clip.text_encoder.base_model,
|
base_model=self.clip.text_encoder.base_model,
|
||||||
model_type=ModelType.TextualInversion,
|
model_type=ModelType.TextualInversion,
|
||||||
|
context=context,
|
||||||
).context.model
|
).context.model
|
||||||
)
|
)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
@ -129,7 +130,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
textual_inversion_manager=ti_manager,
|
textual_inversion_manager=ti_manager,
|
||||||
dtype_for_device_getter=torch_dtype,
|
dtype_for_device_getter=torch_dtype,
|
||||||
truncate_long_prompts=False,
|
truncate_long_prompts=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
conjunction = Compel.parse_prompt_string(self.prompt)
|
conjunction = Compel.parse_prompt_string(self.prompt)
|
||||||
@ -438,7 +439,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
)
|
)
|
||||||
|
|
||||||
class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Pass unmodified prompt to conditioning without compel processing."""
|
||||||
|
|
||||||
type: Literal["sdxl_raw_prompt"] = "sdxl_raw_prompt"
|
type: Literal["sdxl_raw_prompt"] = "sdxl_raw_prompt"
|
||||||
|
|
||||||
|
@ -161,13 +161,13 @@ class InpaintInvocation(BaseInvocation):
|
|||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in self.unet.loras:
|
for lora in self.unet.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.services.model_manager.get_model(
|
||||||
**lora.dict(exclude={"weight"}))
|
**lora.dict(exclude={"weight"}), context=context,)
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context,)
|
||||||
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
|
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict(), context=context,)
|
||||||
|
|
||||||
with vae_info as vae,\
|
with vae_info as vae,\
|
||||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||||
|
@ -83,7 +83,7 @@ def get_scheduler(
|
|||||||
scheduler_name, SCHEDULER_MAP['ddim']
|
scheduler_name, SCHEDULER_MAP['ddim']
|
||||||
)
|
)
|
||||||
orig_scheduler_info = context.services.model_manager.get_model(
|
orig_scheduler_info = context.services.model_manager.get_model(
|
||||||
**scheduler_info.dict()
|
**scheduler_info.dict(), context=context,
|
||||||
)
|
)
|
||||||
with orig_scheduler_info as orig_scheduler:
|
with orig_scheduler_info as orig_scheduler:
|
||||||
scheduler_config = orig_scheduler.config
|
scheduler_config = orig_scheduler.config
|
||||||
@ -270,6 +270,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
model_name=control_info.control_model.model_name,
|
model_name=control_info.control_model.model_name,
|
||||||
model_type=ModelType.ControlNet,
|
model_type=ModelType.ControlNet,
|
||||||
base_model=control_info.control_model.base_model,
|
base_model=control_info.control_model.base_model,
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -321,14 +322,14 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in self.unet.loras:
|
for lora in self.unet.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.services.model_manager.get_model(
|
||||||
**lora.dict(exclude={"weight"})
|
**lora.dict(exclude={"weight"}), context=context,
|
||||||
)
|
)
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(
|
unet_info = context.services.model_manager.get_model(
|
||||||
**self.unet.unet.dict()
|
**self.unet.unet.dict(), context=context,
|
||||||
)
|
)
|
||||||
with ExitStack() as exit_stack,\
|
with ExitStack() as exit_stack,\
|
||||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||||
@ -414,14 +415,14 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in self.unet.loras:
|
for lora in self.unet.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.services.model_manager.get_model(
|
||||||
**lora.dict(exclude={"weight"})
|
**lora.dict(exclude={"weight"}), context=context,
|
||||||
)
|
)
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(
|
unet_info = context.services.model_manager.get_model(
|
||||||
**self.unet.unet.dict()
|
**self.unet.unet.dict(), context=context,
|
||||||
)
|
)
|
||||||
with ExitStack() as exit_stack,\
|
with ExitStack() as exit_stack,\
|
||||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||||
@ -506,7 +507,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
vae_info = context.services.model_manager.get_model(
|
vae_info = context.services.model_manager.get_model(
|
||||||
**self.vae.vae.dict(),
|
**self.vae.vae.dict(), context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
with vae_info as vae:
|
with vae_info as vae:
|
||||||
@ -687,7 +688,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
|
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
|
||||||
vae_info = context.services.model_manager.get_model(
|
vae_info = context.services.model_manager.get_model(
|
||||||
**self.vae.vae.dict(),
|
**self.vae.vae.dict(), context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
from typing import Literal
|
from os.path import exists
|
||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
from pydantic.fields import Field
|
import numpy as np
|
||||||
|
from pydantic import Field, validator
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||||
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
|
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
|
||||||
@ -55,3 +57,41 @@ class DynamicPromptInvocation(BaseInvocation):
|
|||||||
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
|
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
|
||||||
|
|
||||||
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
|
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
|
||||||
|
|
||||||
|
|
||||||
|
class PromptsFromFileInvocation(BaseInvocation):
|
||||||
|
'''Loads prompts from a text file'''
|
||||||
|
# fmt: off
|
||||||
|
type: Literal['prompt_from_file'] = 'prompt_from_file'
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
file_path: str = Field(description="Path to prompt text file")
|
||||||
|
pre_prompt: Optional[str] = Field(description="String to prepend to each prompt")
|
||||||
|
post_prompt: Optional[str] = Field(description="String to append to each prompt")
|
||||||
|
start_line: int = Field(default=1, ge=1, description="Line in the file to start start from")
|
||||||
|
max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)")
|
||||||
|
#fmt: on
|
||||||
|
|
||||||
|
@validator("file_path")
|
||||||
|
def file_path_exists(cls, v):
|
||||||
|
if not exists(v):
|
||||||
|
raise ValueError(FileNotFoundError)
|
||||||
|
return v
|
||||||
|
|
||||||
|
def promptsFromFile(self, file_path: str, pre_prompt: str, post_prompt: str, start_line: int, max_prompts: int):
|
||||||
|
prompts = []
|
||||||
|
start_line -= 1
|
||||||
|
end_line = start_line + max_prompts
|
||||||
|
if max_prompts <= 0:
|
||||||
|
end_line = np.iinfo(np.int32).max
|
||||||
|
with open(file_path) as f:
|
||||||
|
for i, line in enumerate(f):
|
||||||
|
if i >= start_line and i < end_line:
|
||||||
|
prompts.append((pre_prompt or '') + line.strip() + (post_prompt or ''))
|
||||||
|
if i >= end_line:
|
||||||
|
break
|
||||||
|
return prompts
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
|
||||||
|
prompts = self.promptsFromFile(self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts)
|
||||||
|
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
|
||||||
|
@ -1,55 +0,0 @@
|
|||||||
from typing import Literal, Optional
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
|
||||||
from .image import ImageOutput
|
|
||||||
|
|
||||||
|
|
||||||
class RestoreFaceInvocation(BaseInvocation):
|
|
||||||
"""Restores faces in an image."""
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["restore_face"] = "restore_face"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: Optional[ImageField] = Field(description="The input image")
|
|
||||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["restoration", "image"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
|
||||||
results = context.services.restoration.upscale_and_reconstruct(
|
|
||||||
image_list=[[image, 0]],
|
|
||||||
upscale=None,
|
|
||||||
strength=self.strength, # GFPGAN strength
|
|
||||||
save_original=False,
|
|
||||||
image_callback=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Results are image and seed, unwrap for now
|
|
||||||
# TODO: can this return multiple results?
|
|
||||||
image_dto = context.services.images.create(
|
|
||||||
image=results[0][0],
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ImageOutput(
|
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
|
||||||
width=image_dto.width,
|
|
||||||
height=image_dto.height,
|
|
||||||
)
|
|
@ -1,48 +1,112 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
||||||
|
from pathlib import Path, PosixPath
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Union, cast
|
||||||
|
|
||||||
|
import cv2 as cv
|
||||||
|
import numpy as np
|
||||||
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
|
from PIL import Image
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
|
||||||
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
from .image import ImageOutput
|
from .image import ImageOutput
|
||||||
|
|
||||||
|
# TODO: Populate this from disk?
|
||||||
|
# TODO: Use model manager to load?
|
||||||
|
REALESRGAN_MODELS = Literal[
|
||||||
|
"RealESRGAN_x4plus.pth",
|
||||||
|
"RealESRGAN_x4plus_anime_6B.pth",
|
||||||
|
"ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||||
|
]
|
||||||
|
|
||||||
class UpscaleInvocation(BaseInvocation):
|
|
||||||
"""Upscales an image."""
|
|
||||||
|
|
||||||
# fmt: off
|
class RealESRGANInvocation(BaseInvocation):
|
||||||
type: Literal["upscale"] = "upscale"
|
"""Upscales an image using RealESRGAN."""
|
||||||
|
|
||||||
# Inputs
|
type: Literal["realesrgan"] = "realesrgan"
|
||||||
image: Optional[ImageField] = Field(description="The input image", default=None)
|
image: Union[ImageField, None] = Field(default=None, description="The input image")
|
||||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
|
model_name: REALESRGAN_MODELS = Field(
|
||||||
level: Literal[2, 4] = Field(default=2, description="The upscale level")
|
default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use"
|
||||||
# fmt: on
|
)
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["upscaling", "image"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
results = context.services.restoration.upscale_and_reconstruct(
|
models_path = context.services.configuration.models_path
|
||||||
image_list=[[image, 0]],
|
|
||||||
upscale=(self.level, self.strength),
|
rrdbnet_model = None
|
||||||
strength=0.0, # GFPGAN strength
|
netscale = None
|
||||||
save_original=False,
|
esrgan_model_path = None
|
||||||
image_callback=None,
|
|
||||||
|
if self.model_name in [
|
||||||
|
"RealESRGAN_x4plus.pth",
|
||||||
|
"ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||||
|
]:
|
||||||
|
# x4 RRDBNet model
|
||||||
|
rrdbnet_model = RRDBNet(
|
||||||
|
num_in_ch=3,
|
||||||
|
num_out_ch=3,
|
||||||
|
num_feat=64,
|
||||||
|
num_block=23,
|
||||||
|
num_grow_ch=32,
|
||||||
|
scale=4,
|
||||||
|
)
|
||||||
|
netscale = 4
|
||||||
|
elif self.model_name in ["RealESRGAN_x4plus_anime_6B.pth"]:
|
||||||
|
# x4 RRDBNet model, 6 blocks
|
||||||
|
rrdbnet_model = RRDBNet(
|
||||||
|
num_in_ch=3,
|
||||||
|
num_out_ch=3,
|
||||||
|
num_feat=64,
|
||||||
|
num_block=6, # 6 blocks
|
||||||
|
num_grow_ch=32,
|
||||||
|
scale=4,
|
||||||
|
)
|
||||||
|
netscale = 4
|
||||||
|
# TODO: add x2 models handling?
|
||||||
|
# elif self.model_name in ["RealESRGAN_x2plus"]:
|
||||||
|
# # x2 RRDBNet model
|
||||||
|
# model = RRDBNet(
|
||||||
|
# num_in_ch=3,
|
||||||
|
# num_out_ch=3,
|
||||||
|
# num_feat=64,
|
||||||
|
# num_block=23,
|
||||||
|
# num_grow_ch=32,
|
||||||
|
# scale=2,
|
||||||
|
# )
|
||||||
|
# model_path = Path()
|
||||||
|
# netscale = 2
|
||||||
|
else:
|
||||||
|
msg = f"Invalid RealESRGAN model: {self.model_name}"
|
||||||
|
context.services.logger.error(msg)
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
|
||||||
|
|
||||||
|
upsampler = RealESRGANer(
|
||||||
|
scale=netscale,
|
||||||
|
model_path=str(models_path / esrgan_model_path),
|
||||||
|
model=rrdbnet_model,
|
||||||
|
half=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Results are image and seed, unwrap for now
|
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
|
||||||
# TODO: can this return multiple results?
|
cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
|
||||||
|
|
||||||
|
# We can pass an `outscale` value here, but it just resizes the image by that factor after
|
||||||
|
# upscaling, so it's kinda pointless for our purposes. If you want something other than 4x
|
||||||
|
# upscaling, you'll need to add a resize node after this one.
|
||||||
|
upscaled_image, img_mode = upsampler.enhance(cv_image)
|
||||||
|
|
||||||
|
# back to PIL
|
||||||
|
pil_image = Image.fromarray(
|
||||||
|
cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)
|
||||||
|
).convert("RGBA")
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=results[0][0],
|
image=pil_image,
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
|
@ -271,13 +271,13 @@ class InvokeAISettings(BaseSettings):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _excluded(self)->List[str]:
|
def _excluded(self)->List[str]:
|
||||||
# combination of deprecated parameters and internal ones that shouldn't be exposed
|
# internal fields that shouldn't be exposed as command line options
|
||||||
return ['type','initconf']
|
return ['type','initconf']
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _excluded_from_yaml(self)->List[str]:
|
def _excluded_from_yaml(self)->List[str]:
|
||||||
# combination of deprecated parameters and internal ones that shouldn't be exposed
|
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
|
||||||
return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model']
|
return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model', 'restore']
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
env_file_encoding = 'utf-8'
|
env_file_encoding = 'utf-8'
|
||||||
@ -366,7 +366,7 @@ setting environment variables INVOKEAI_<setting>.
|
|||||||
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
|
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
|
||||||
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
|
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
|
||||||
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
|
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
|
||||||
restore : bool = Field(default=True, description="Enable/disable face restoration code", category='Features')
|
restore : bool = Field(default=True, description="Enable/disable face restoration code (DEPRECATED)", category='DEPRECATED')
|
||||||
|
|
||||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
||||||
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
||||||
|
@ -105,8 +105,6 @@ class EventServiceBase:
|
|||||||
def emit_model_load_started (
|
def emit_model_load_started (
|
||||||
self,
|
self,
|
||||||
graph_execution_state_id: str,
|
graph_execution_state_id: str,
|
||||||
node: dict,
|
|
||||||
source_node_id: str,
|
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
@ -117,8 +115,6 @@ class EventServiceBase:
|
|||||||
event_name="model_load_started",
|
event_name="model_load_started",
|
||||||
payload=dict(
|
payload=dict(
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
node=node,
|
|
||||||
source_node_id=source_node_id,
|
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
@ -129,8 +125,6 @@ class EventServiceBase:
|
|||||||
def emit_model_load_completed(
|
def emit_model_load_completed(
|
||||||
self,
|
self,
|
||||||
graph_execution_state_id: str,
|
graph_execution_state_id: str,
|
||||||
node: dict,
|
|
||||||
source_node_id: str,
|
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
@ -142,12 +136,12 @@ class EventServiceBase:
|
|||||||
event_name="model_load_completed",
|
event_name="model_load_completed",
|
||||||
payload=dict(
|
payload=dict(
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
node=node,
|
|
||||||
source_node_id=source_node_id,
|
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
submodel=submodel,
|
submodel=submodel,
|
||||||
model_info=model_info,
|
hash=model_info.hash,
|
||||||
|
location=model_info.location,
|
||||||
|
precision=str(model_info.precision),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -10,10 +10,9 @@ if TYPE_CHECKING:
|
|||||||
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
||||||
from invokeai.app.services.events import EventServiceBase
|
from invokeai.app.services.events import EventServiceBase
|
||||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||||
from invokeai.app.services.restoration_services import RestorationServices
|
|
||||||
from invokeai.app.services.invocation_queue import InvocationQueueABC
|
from invokeai.app.services.invocation_queue import InvocationQueueABC
|
||||||
from invokeai.app.services.item_storage import ItemStorageABC
|
from invokeai.app.services.item_storage import ItemStorageABC
|
||||||
from invokeai.app.services.config import InvokeAISettings
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
|
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
|
||||||
from invokeai.app.services.invoker import InvocationProcessorABC
|
from invokeai.app.services.invoker import InvocationProcessorABC
|
||||||
|
|
||||||
@ -24,7 +23,7 @@ class InvocationServices:
|
|||||||
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
||||||
board_images: "BoardImagesServiceABC"
|
board_images: "BoardImagesServiceABC"
|
||||||
boards: "BoardServiceABC"
|
boards: "BoardServiceABC"
|
||||||
configuration: "InvokeAISettings"
|
configuration: "InvokeAIAppConfig"
|
||||||
events: "EventServiceBase"
|
events: "EventServiceBase"
|
||||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
|
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
|
||||||
graph_library: "ItemStorageABC"["LibraryGraph"]
|
graph_library: "ItemStorageABC"["LibraryGraph"]
|
||||||
@ -34,13 +33,12 @@ class InvocationServices:
|
|||||||
model_manager: "ModelManagerServiceBase"
|
model_manager: "ModelManagerServiceBase"
|
||||||
processor: "InvocationProcessorABC"
|
processor: "InvocationProcessorABC"
|
||||||
queue: "InvocationQueueABC"
|
queue: "InvocationQueueABC"
|
||||||
restoration: "RestorationServices"
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
board_images: "BoardImagesServiceABC",
|
board_images: "BoardImagesServiceABC",
|
||||||
boards: "BoardServiceABC",
|
boards: "BoardServiceABC",
|
||||||
configuration: "InvokeAISettings",
|
configuration: "InvokeAIAppConfig",
|
||||||
events: "EventServiceBase",
|
events: "EventServiceBase",
|
||||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
|
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
|
||||||
graph_library: "ItemStorageABC"["LibraryGraph"],
|
graph_library: "ItemStorageABC"["LibraryGraph"],
|
||||||
@ -50,7 +48,6 @@ class InvocationServices:
|
|||||||
model_manager: "ModelManagerServiceBase",
|
model_manager: "ModelManagerServiceBase",
|
||||||
processor: "InvocationProcessorABC",
|
processor: "InvocationProcessorABC",
|
||||||
queue: "InvocationQueueABC",
|
queue: "InvocationQueueABC",
|
||||||
restoration: "RestorationServices",
|
|
||||||
):
|
):
|
||||||
self.board_images = board_images
|
self.board_images = board_images
|
||||||
self.boards = boards
|
self.boards = boards
|
||||||
@ -65,4 +62,3 @@ class InvocationServices:
|
|||||||
self.model_manager = model_manager
|
self.model_manager = model_manager
|
||||||
self.processor = processor
|
self.processor = processor
|
||||||
self.queue = queue
|
self.queue = queue
|
||||||
self.restoration = restoration
|
|
||||||
|
@ -18,6 +18,7 @@ from invokeai.backend.model_management import (
|
|||||||
SchedulerPredictionType,
|
SchedulerPredictionType,
|
||||||
ModelMerger,
|
ModelMerger,
|
||||||
MergeInterpolationMethod,
|
MergeInterpolationMethod,
|
||||||
|
ModelNotFoundException,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_management.model_search import FindModels
|
from invokeai.backend.model_management.model_search import FindModels
|
||||||
|
|
||||||
@ -145,7 +146,7 @@ class ModelManagerServiceBase(ABC):
|
|||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
Update the named model with a dictionary of attributes. Will fail with a
|
Update the named model with a dictionary of attributes. Will fail with a
|
||||||
KeyErrorException if the name does not already exist.
|
ModelNotFoundException if the name does not already exist.
|
||||||
|
|
||||||
On a successful update, the config will be changed in memory. Will fail
|
On a successful update, the config will be changed in memory. Will fail
|
||||||
with an assertion error if provided attributes are incorrect or
|
with an assertion error if provided attributes are incorrect or
|
||||||
@ -338,7 +339,6 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
submodel: Optional[SubModelType] = None,
|
submodel: Optional[SubModelType] = None,
|
||||||
node: Optional[BaseInvocation] = None,
|
|
||||||
context: Optional[InvocationContext] = None,
|
context: Optional[InvocationContext] = None,
|
||||||
) -> ModelInfo:
|
) -> ModelInfo:
|
||||||
"""
|
"""
|
||||||
@ -346,11 +346,9 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
part (such as the vae) of a diffusers mode.
|
part (such as the vae) of a diffusers mode.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# if we are called from within a node, then we get to emit
|
# we can emit model loading events if we are executing with access to the invocation context
|
||||||
# load start and complete events
|
if context:
|
||||||
if node and context:
|
|
||||||
self._emit_load_event(
|
self._emit_load_event(
|
||||||
node=node,
|
|
||||||
context=context,
|
context=context,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
@ -365,9 +363,8 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
submodel,
|
submodel,
|
||||||
)
|
)
|
||||||
|
|
||||||
if node and context:
|
if context:
|
||||||
self._emit_load_event(
|
self._emit_load_event(
|
||||||
node=node,
|
|
||||||
context=context,
|
context=context,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
@ -451,14 +448,14 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
Update the named model with a dictionary of attributes. Will fail with a
|
Update the named model with a dictionary of attributes. Will fail with a
|
||||||
KeyError exception if the name does not already exist.
|
ModelNotFoundException exception if the name does not already exist.
|
||||||
On a successful update, the config will be changed in memory. Will fail
|
On a successful update, the config will be changed in memory. Will fail
|
||||||
with an assertion error if provided attributes are incorrect or
|
with an assertion error if provided attributes are incorrect or
|
||||||
the model name is missing. Call commit() to write changes to disk.
|
the model name is missing. Call commit() to write changes to disk.
|
||||||
"""
|
"""
|
||||||
self.logger.debug(f'update model {model_name}')
|
self.logger.debug(f'update model {model_name}')
|
||||||
if not self.model_exists(model_name, base_model, model_type):
|
if not self.model_exists(model_name, base_model, model_type):
|
||||||
raise KeyError(f"Unknown model {model_name}")
|
raise ModelNotFoundException(f"Unknown model {model_name}")
|
||||||
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
|
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
|
||||||
|
|
||||||
def del_model(
|
def del_model(
|
||||||
@ -509,23 +506,19 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
|
|
||||||
def _emit_load_event(
|
def _emit_load_event(
|
||||||
self,
|
self,
|
||||||
node,
|
|
||||||
context,
|
context,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
submodel: SubModelType,
|
submodel: Optional[SubModelType] = None,
|
||||||
model_info: Optional[ModelInfo] = None,
|
model_info: Optional[ModelInfo] = None,
|
||||||
):
|
):
|
||||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||||
raise CanceledException()
|
raise CanceledException()
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[node.id]
|
|
||||||
if model_info:
|
if model_info:
|
||||||
context.services.events.emit_model_load_completed(
|
context.services.events.emit_model_load_completed(
|
||||||
graph_execution_state_id=context.graph_execution_state_id,
|
graph_execution_state_id=context.graph_execution_state_id,
|
||||||
node=node.dict(),
|
|
||||||
source_node_id=source_node_id,
|
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
@ -535,8 +528,6 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
else:
|
else:
|
||||||
context.services.events.emit_model_load_started(
|
context.services.events.emit_model_load_started(
|
||||||
graph_execution_state_id=context.graph_execution_state_id,
|
graph_execution_state_id=context.graph_execution_state_id,
|
||||||
node=node.dict(),
|
|
||||||
source_node_id=source_node_id,
|
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
|
@ -1,113 +0,0 @@
|
|||||||
import sys
|
|
||||||
import traceback
|
|
||||||
import torch
|
|
||||||
from typing import types
|
|
||||||
from ...backend.restoration import Restoration
|
|
||||||
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
|
||||||
|
|
||||||
# This should be a real base class for postprocessing functions,
|
|
||||||
# but right now we just instantiate the existing gfpgan, esrgan
|
|
||||||
# and codeformer functions.
|
|
||||||
class RestorationServices:
|
|
||||||
'''Face restoration and upscaling'''
|
|
||||||
|
|
||||||
def __init__(self,args,logger:types.ModuleType):
|
|
||||||
try:
|
|
||||||
gfpgan, codeformer, esrgan = None, None, None
|
|
||||||
if args.restore or args.esrgan:
|
|
||||||
restoration = Restoration()
|
|
||||||
# TODO: redo for new model structure
|
|
||||||
if False and args.restore:
|
|
||||||
gfpgan, codeformer = restoration.load_face_restore_models(
|
|
||||||
args.gfpgan_model_path
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info("Face restoration disabled")
|
|
||||||
if False and args.esrgan:
|
|
||||||
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
|
|
||||||
else:
|
|
||||||
logger.info("Upscaling disabled")
|
|
||||||
else:
|
|
||||||
logger.info("Face restoration and upscaling disabled")
|
|
||||||
except (ModuleNotFoundError, ImportError):
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
|
|
||||||
self.device = torch.device(choose_torch_device())
|
|
||||||
self.gfpgan = gfpgan
|
|
||||||
self.codeformer = codeformer
|
|
||||||
self.esrgan = esrgan
|
|
||||||
self.logger = logger
|
|
||||||
self.logger.info('Face restoration initialized')
|
|
||||||
|
|
||||||
# note that this one method does gfpgan and codepath reconstruction, as well as
|
|
||||||
# esrgan upscaling
|
|
||||||
# TO DO: refactor into separate methods
|
|
||||||
def upscale_and_reconstruct(
|
|
||||||
self,
|
|
||||||
image_list,
|
|
||||||
facetool="gfpgan",
|
|
||||||
upscale=None,
|
|
||||||
upscale_denoise_str=0.75,
|
|
||||||
strength=0.0,
|
|
||||||
codeformer_fidelity=0.75,
|
|
||||||
save_original=False,
|
|
||||||
image_callback=None,
|
|
||||||
prefix=None,
|
|
||||||
):
|
|
||||||
results = []
|
|
||||||
for r in image_list:
|
|
||||||
image, seed = r
|
|
||||||
try:
|
|
||||||
if strength > 0:
|
|
||||||
if self.gfpgan is not None or self.codeformer is not None:
|
|
||||||
if facetool == "gfpgan":
|
|
||||||
if self.gfpgan is None:
|
|
||||||
self.logger.info(
|
|
||||||
"GFPGAN not found. Face restoration is disabled."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
image = self.gfpgan.process(image, strength, seed)
|
|
||||||
if facetool == "codeformer":
|
|
||||||
if self.codeformer is None:
|
|
||||||
self.logger.info(
|
|
||||||
"CodeFormer not found. Face restoration is disabled."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
cf_device = (
|
|
||||||
CPU_DEVICE if self.device == MPS_DEVICE else self.device
|
|
||||||
)
|
|
||||||
image = self.codeformer.process(
|
|
||||||
image=image,
|
|
||||||
strength=strength,
|
|
||||||
device=cf_device,
|
|
||||||
seed=seed,
|
|
||||||
fidelity=codeformer_fidelity,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.logger.info("Face Restoration is disabled.")
|
|
||||||
if upscale is not None:
|
|
||||||
if self.esrgan is not None:
|
|
||||||
if len(upscale) < 2:
|
|
||||||
upscale.append(0.75)
|
|
||||||
image = self.esrgan.process(
|
|
||||||
image,
|
|
||||||
upscale[1],
|
|
||||||
seed,
|
|
||||||
int(upscale[0]),
|
|
||||||
denoise_str=upscale_denoise_str,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.logger.info("ESRGAN is disabled. Image not upscaled.")
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.info(
|
|
||||||
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if image_callback is not None:
|
|
||||||
image_callback(image, seed, upscaled=True, use_prefix=prefix)
|
|
||||||
else:
|
|
||||||
r[0] = image
|
|
||||||
|
|
||||||
results.append([image, seed])
|
|
||||||
|
|
||||||
return results
|
|
@ -30,8 +30,6 @@ from huggingface_hub import login as hf_hub_login
|
|||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoProcessor,
|
|
||||||
CLIPSegForImageSegmentation,
|
|
||||||
CLIPTextModel,
|
CLIPTextModel,
|
||||||
CLIPTokenizer,
|
CLIPTokenizer,
|
||||||
AutoFeatureExtractor,
|
AutoFeatureExtractor,
|
||||||
@ -45,7 +43,6 @@ from invokeai.app.services.config import (
|
|||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
||||||
from invokeai.frontend.install.widgets import (
|
from invokeai.frontend.install.widgets import (
|
||||||
SingleSelectColumns,
|
|
||||||
CenteredButtonPress,
|
CenteredButtonPress,
|
||||||
IntTitleSlider,
|
IntTitleSlider,
|
||||||
set_min_terminal_size,
|
set_min_terminal_size,
|
||||||
@ -72,7 +69,6 @@ transformers.logging.set_verbosity_error()
|
|||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
Model_dir = "models"
|
Model_dir = "models"
|
||||||
Weights_dir = "ldm/stable-diffusion-v1/"
|
|
||||||
|
|
||||||
Default_config_file = config.model_conf_path
|
Default_config_file = config.model_conf_path
|
||||||
SD_Configs = config.legacy_conf_path
|
SD_Configs = config.legacy_conf_path
|
||||||
@ -226,64 +222,30 @@ def download_conversion_models():
|
|||||||
|
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
def download_realesrgan():
|
def download_realesrgan():
|
||||||
logger.info("Installing models from RealESRGAN...")
|
logger.info("Installing RealESRGAN models...")
|
||||||
model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
|
URLs = [
|
||||||
wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth"
|
dict(
|
||||||
|
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||||
model_dest = config.root_path / "models/core/upscaling/realesrgan/realesr-general-x4v3.pth"
|
dest = "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
|
||||||
wdn_model_dest = config.root_path / "models/core/upscaling/realesrgan/realesr-general-wdn-x4v3.pth"
|
description = "RealESRGAN_x4plus.pth",
|
||||||
|
),
|
||||||
download_with_progress_bar(model_url, str(model_dest), "RealESRGAN")
|
dict(
|
||||||
download_with_progress_bar(wdn_model_url, str(wdn_model_dest), "RealESRGANwdn")
|
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||||
|
dest = "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
|
||||||
|
description = "RealESRGAN_x4plus_anime_6B.pth",
|
||||||
def download_gfpgan():
|
),
|
||||||
logger.info("Installing GFPGAN models...")
|
dict(
|
||||||
for model in (
|
url= "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||||
[
|
dest= "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||||
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
|
description = "ESRGAN_SRx4_DF2KOST_official.pth",
|
||||||
"./models/core/face_restoration/gfpgan/GFPGANv1.4.pth",
|
),
|
||||||
],
|
]
|
||||||
[
|
for model in URLs:
|
||||||
"https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth",
|
download_with_progress_bar(model['url'], config.models_path / model['dest'], model['description'])
|
||||||
"./models/core/face_restoration/gfpgan/weights/detection_Resnet50_Final.pth",
|
|
||||||
],
|
|
||||||
[
|
|
||||||
"https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth",
|
|
||||||
"./models/core/face_restoration/gfpgan/weights/parsing_parsenet.pth",
|
|
||||||
],
|
|
||||||
):
|
|
||||||
model_url, model_dest = model[0], config.root_path / model[1]
|
|
||||||
download_with_progress_bar(model_url, str(model_dest), "GFPGAN weights")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
def download_codeformer():
|
|
||||||
logger.info("Installing CodeFormer model file...")
|
|
||||||
model_url = (
|
|
||||||
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
|
||||||
)
|
|
||||||
model_dest = config.root_path / "models/core/face_restoration/codeformer/codeformer.pth"
|
|
||||||
download_with_progress_bar(model_url, str(model_dest), "CodeFormer")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def download_clipseg():
|
|
||||||
logger.info("Installing clipseg model for text-based masking...")
|
|
||||||
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
|
||||||
try:
|
|
||||||
hf_download_from_pretrained(AutoProcessor, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
|
|
||||||
hf_download_from_pretrained(CLIPSegForImageSegmentation, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
|
|
||||||
except Exception:
|
|
||||||
logger.info("Error installing clipseg model:")
|
|
||||||
logger.info(traceback.format_exc())
|
|
||||||
|
|
||||||
|
|
||||||
def download_support_models():
|
def download_support_models():
|
||||||
download_realesrgan()
|
download_realesrgan()
|
||||||
download_gfpgan()
|
|
||||||
download_codeformer()
|
|
||||||
download_clipseg()
|
|
||||||
download_conversion_models()
|
download_conversion_models()
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
@ -666,7 +628,7 @@ def run_console_ui(
|
|||||||
|
|
||||||
# The third argument is needed in the Windows 11 environment to
|
# The third argument is needed in the Windows 11 environment to
|
||||||
# launch a console window running this program.
|
# launch a console window running this program.
|
||||||
set_min_terminal_size(MIN_COLS, MIN_LINES,'invokeai-configure')
|
set_min_terminal_size(MIN_COLS, MIN_LINES)
|
||||||
|
|
||||||
# the install-models application spawns a subprocess to install
|
# the install-models application spawns a subprocess to install
|
||||||
# models, and will crash unless this is set before running.
|
# models, and will crash unless this is set before running.
|
||||||
@ -743,7 +705,7 @@ def migrate_if_needed(opt: Namespace, root: Path)->bool:
|
|||||||
old_init_file = root / 'invokeai.init'
|
old_init_file = root / 'invokeai.init'
|
||||||
new_init_file = root / 'invokeai.yaml'
|
new_init_file = root / 'invokeai.yaml'
|
||||||
old_hub = root / 'models/hub'
|
old_hub = root / 'models/hub'
|
||||||
migration_needed = old_init_file.exists() and not new_init_file.exists() or old_hub.exists()
|
migration_needed = (old_init_file.exists() and not new_init_file.exists()) and old_hub.exists()
|
||||||
|
|
||||||
if migration_needed:
|
if migration_needed:
|
||||||
if opt.yes_to_all or \
|
if opt.yes_to_all or \
|
||||||
@ -858,9 +820,9 @@ def main():
|
|||||||
download_support_models()
|
download_support_models()
|
||||||
|
|
||||||
if opt.skip_sd_weights:
|
if opt.skip_sd_weights:
|
||||||
logger.info("\n** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **")
|
logger.warning("SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST")
|
||||||
elif models_to_download:
|
elif models_to_download:
|
||||||
logger.info("\n** DOWNLOADING DIFFUSION WEIGHTS **")
|
logger.info("DOWNLOADING DIFFUSION WEIGHTS")
|
||||||
process_and_execute(opt, models_to_download)
|
process_and_execute(opt, models_to_download)
|
||||||
|
|
||||||
postscript(errors=errors)
|
postscript(errors=errors)
|
||||||
|
@ -117,6 +117,7 @@ class ModelInstall(object):
|
|||||||
|
|
||||||
# supplement with entries in models.yaml
|
# supplement with entries in models.yaml
|
||||||
installed_models = self.mgr.list_models()
|
installed_models = self.mgr.list_models()
|
||||||
|
|
||||||
for md in installed_models:
|
for md in installed_models:
|
||||||
base = md['base_model']
|
base = md['base_model']
|
||||||
model_type = md['model_type']
|
model_type = md['model_type']
|
||||||
@ -134,6 +135,12 @@ class ModelInstall(object):
|
|||||||
)
|
)
|
||||||
return {x : model_dict[x] for x in sorted(model_dict.keys(),key=lambda y: model_dict[y].name.lower())}
|
return {x : model_dict[x] for x in sorted(model_dict.keys(),key=lambda y: model_dict[y].name.lower())}
|
||||||
|
|
||||||
|
def list_models(self, model_type):
|
||||||
|
installed = self.mgr.list_models(model_type=model_type)
|
||||||
|
print(f'Installed models of type `{model_type}`:')
|
||||||
|
for i in installed:
|
||||||
|
print(f"{i['model_name']}\t{i['base_model']}\t{i['path']}")
|
||||||
|
|
||||||
def starter_models(self)->Set[str]:
|
def starter_models(self)->Set[str]:
|
||||||
models = set()
|
models = set()
|
||||||
for key, value in self.datasets.items():
|
for key, value in self.datasets.items():
|
||||||
@ -205,7 +212,7 @@ class ModelInstall(object):
|
|||||||
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
|
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
|
||||||
]
|
]
|
||||||
):
|
):
|
||||||
models_installed.update(self._install_path(path))
|
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
|
||||||
|
|
||||||
# recursive scan
|
# recursive scan
|
||||||
elif path.is_dir():
|
elif path.is_dir():
|
||||||
|
@ -3,6 +3,6 @@ Initialization file for invokeai.backend.model_management
|
|||||||
"""
|
"""
|
||||||
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
|
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
|
||||||
from .model_cache import ModelCache
|
from .model_cache import ModelCache
|
||||||
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType, ModelNotFoundException
|
||||||
from .model_merge import ModelMerger, MergeInterpolationMethod
|
from .model_merge import ModelMerger, MergeInterpolationMethod
|
||||||
|
|
||||||
|
@ -552,7 +552,7 @@ class ModelManager(object):
|
|||||||
model_config = self.models.get(model_key)
|
model_config = self.models.get(model_key)
|
||||||
if not model_config:
|
if not model_config:
|
||||||
self.logger.error(f'Unknown model {model_name}')
|
self.logger.error(f'Unknown model {model_name}')
|
||||||
raise KeyError(f'Unknown model {model_name}')
|
raise ModelNotFoundException(f'Unknown model {model_name}')
|
||||||
|
|
||||||
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
||||||
if base_model is not None and cur_base_model != base_model:
|
if base_model is not None and cur_base_model != base_model:
|
||||||
@ -568,6 +568,9 @@ class ModelManager(object):
|
|||||||
model_type=cur_model_type,
|
model_type=cur_model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# expose paths as absolute to help web UI
|
||||||
|
if path := model_dict.get('path'):
|
||||||
|
model_dict['path'] = str(self.app_config.root_path / path)
|
||||||
models.append(model_dict)
|
models.append(model_dict)
|
||||||
|
|
||||||
return models
|
return models
|
||||||
@ -596,7 +599,7 @@ class ModelManager(object):
|
|||||||
model_cfg = self.models.pop(model_key, None)
|
model_cfg = self.models.pop(model_key, None)
|
||||||
|
|
||||||
if model_cfg is None:
|
if model_cfg is None:
|
||||||
raise KeyError(f"Unknown model {model_key}")
|
raise ModelNotFoundException(f"Unknown model {model_key}")
|
||||||
|
|
||||||
# note: it not garantie to release memory(model can has other references)
|
# note: it not garantie to release memory(model can has other references)
|
||||||
cache_ids = self.cache_keys.pop(model_key, [])
|
cache_ids = self.cache_keys.pop(model_key, [])
|
||||||
@ -635,6 +638,10 @@ class ModelManager(object):
|
|||||||
The returned dict has the same format as the dict returned by
|
The returned dict has the same format as the dict returned by
|
||||||
model_info().
|
model_info().
|
||||||
"""
|
"""
|
||||||
|
# relativize paths as they go in - this makes it easier to move the root directory around
|
||||||
|
if path := model_attributes.get('path'):
|
||||||
|
if Path(path).is_relative_to(self.app_config.root_path):
|
||||||
|
model_attributes['path'] = str(Path(path).relative_to(self.app_config.root_path))
|
||||||
|
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
model_config = model_class.create_config(**model_attributes)
|
model_config = model_class.create_config(**model_attributes)
|
||||||
@ -689,7 +696,7 @@ class ModelManager(object):
|
|||||||
model_key = self.create_key(model_name, base_model, model_type)
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
model_cfg = self.models.get(model_key, None)
|
model_cfg = self.models.get(model_key, None)
|
||||||
if not model_cfg:
|
if not model_cfg:
|
||||||
raise KeyError(f"Unknown model: {model_key}")
|
raise ModelNotFoundException(f"Unknown model: {model_key}")
|
||||||
|
|
||||||
old_path = self.app_config.root_path / model_cfg.path
|
old_path = self.app_config.root_path / model_cfg.path
|
||||||
new_name = new_name or model_name
|
new_name = new_name or model_name
|
||||||
@ -700,7 +707,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
# if this is a model file/directory that we manage ourselves, we need to move it
|
# if this is a model file/directory that we manage ourselves, we need to move it
|
||||||
if old_path.is_relative_to(self.app_config.models_path):
|
if old_path.is_relative_to(self.app_config.models_path):
|
||||||
new_path = self.app_config.root_path / 'models' / new_base.value / model_type.value / new_name
|
new_path = self.app_config.root_path / 'models' / BaseModelType(new_base).value / ModelType(model_type).value / new_name
|
||||||
move(old_path, new_path)
|
move(old_path, new_path)
|
||||||
model_cfg.path = str(new_path.relative_to(self.app_config.root_path))
|
model_cfg.path = str(new_path.relative_to(self.app_config.root_path))
|
||||||
|
|
||||||
@ -908,7 +915,6 @@ class ModelManager(object):
|
|||||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||||
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
|
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
|
||||||
|
|
||||||
|
|
||||||
class ScanAndImport(ModelSearch):
|
class ScanAndImport(ModelSearch):
|
||||||
def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall):
|
def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall):
|
||||||
super().__init__(directories, logger)
|
super().__init__(directories, logger)
|
||||||
@ -965,7 +971,7 @@ class ModelManager(object):
|
|||||||
that model.
|
that model.
|
||||||
|
|
||||||
May return the following exceptions:
|
May return the following exceptions:
|
||||||
- KeyError - one or more of the items to import is not a valid path, repo_id or URL
|
- ModelNotFoundException - one or more of the items to import is not a valid path, repo_id or URL
|
||||||
- ValueError - a corresponding model already exists
|
- ValueError - a corresponding model already exists
|
||||||
'''
|
'''
|
||||||
# avoid circular import here
|
# avoid circular import here
|
||||||
|
@ -12,6 +12,7 @@ from picklescan.scanner import scan_file_path
|
|||||||
from .models import (
|
from .models import (
|
||||||
BaseModelType, ModelType, ModelVariantType,
|
BaseModelType, ModelType, ModelVariantType,
|
||||||
SchedulerPredictionType, SilenceWarnings,
|
SchedulerPredictionType, SilenceWarnings,
|
||||||
|
InvalidModelException
|
||||||
)
|
)
|
||||||
from .models.base import read_checkpoint_meta
|
from .models.base import read_checkpoint_meta
|
||||||
|
|
||||||
@ -61,7 +62,7 @@ class ModelProbe(object):
|
|||||||
elif isinstance(model,(dict,ModelMixin,ConfigMixin)):
|
elif isinstance(model,(dict,ModelMixin,ConfigMixin)):
|
||||||
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
|
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
|
||||||
else:
|
else:
|
||||||
raise ValueError("model parameter {model} is neither a Path, nor a model")
|
raise InvalidModelException("model parameter {model} is neither a Path, nor a model")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def probe(cls,
|
def probe(cls,
|
||||||
@ -141,7 +142,7 @@ class ModelProbe(object):
|
|||||||
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
||||||
return ModelType.TextualInversion
|
return ModelType.TextualInversion
|
||||||
|
|
||||||
raise ValueError(f"Unable to determine model type for {model_path}")
|
raise InvalidModelException(f"Unable to determine model type for {model_path}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
|
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
|
||||||
@ -171,7 +172,7 @@ class ModelProbe(object):
|
|||||||
return type
|
return type
|
||||||
|
|
||||||
# give up
|
# give up
|
||||||
raise ValueError(f"Unable to determine model type for {folder_path}")
|
raise InvalidModelException(f"Unable to determine model type for {folder_path}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
|
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
|
||||||
@ -240,7 +241,7 @@ class CheckpointProbeBase(ProbeBase):
|
|||||||
elif in_channels == 4:
|
elif in_channels == 4:
|
||||||
return ModelVariantType.Normal
|
return ModelVariantType.Normal
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}")
|
raise InvalidModelException(f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}")
|
||||||
|
|
||||||
class PipelineCheckpointProbe(CheckpointProbeBase):
|
class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||||
def get_base_type(self)->BaseModelType:
|
def get_base_type(self)->BaseModelType:
|
||||||
@ -254,7 +255,7 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
|
|||||||
# TODO: Verify that this is correct! Need an XL checkpoint file for this.
|
# TODO: Verify that this is correct! Need an XL checkpoint file for this.
|
||||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
|
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
|
||||||
return BaseModelType.StableDiffusionXL
|
return BaseModelType.StableDiffusionXL
|
||||||
raise ValueError("Cannot determine base type")
|
raise InvalidModelException("Cannot determine base type")
|
||||||
|
|
||||||
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
|
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
|
||||||
type = self.get_base_type()
|
type = self.get_base_type()
|
||||||
@ -335,7 +336,7 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
|
|||||||
return BaseModelType.StableDiffusion2
|
return BaseModelType.StableDiffusion2
|
||||||
elif self.checkpoint_path and self.helper:
|
elif self.checkpoint_path and self.helper:
|
||||||
return self.helper(self.checkpoint_path)
|
return self.helper(self.checkpoint_path)
|
||||||
raise ValueError("Unable to determine base type for {self.checkpoint_path}")
|
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
|
||||||
|
|
||||||
########################################################
|
########################################################
|
||||||
# classes for probing folders
|
# classes for probing folders
|
||||||
@ -371,7 +372,7 @@ class PipelineFolderProbe(FolderProbeBase):
|
|||||||
elif unet_conf['cross_attention_dim'] == 2048:
|
elif unet_conf['cross_attention_dim'] == 2048:
|
||||||
return BaseModelType.StableDiffusionXL
|
return BaseModelType.StableDiffusionXL
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unknown base model for {self.folder_path}')
|
raise InvalidModelException(f'Unknown base model for {self.folder_path}')
|
||||||
|
|
||||||
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
|
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
|
||||||
if self.model:
|
if self.model:
|
||||||
@ -428,7 +429,7 @@ class ControlNetFolderProbe(FolderProbeBase):
|
|||||||
def get_base_type(self)->BaseModelType:
|
def get_base_type(self)->BaseModelType:
|
||||||
config_file = self.folder_path / 'config.json'
|
config_file = self.folder_path / 'config.json'
|
||||||
if not config_file.exists():
|
if not config_file.exists():
|
||||||
raise ValueError(f"Cannot determine base type for {self.folder_path}")
|
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
||||||
with open(config_file,'r') as file:
|
with open(config_file,'r') as file:
|
||||||
config = json.load(file)
|
config = json.load(file)
|
||||||
# no obvious way to distinguish between sd2-base and sd2-768
|
# no obvious way to distinguish between sd2-base and sd2-768
|
||||||
@ -445,7 +446,7 @@ class LoRAFolderProbe(FolderProbeBase):
|
|||||||
model_file = base_file
|
model_file = base_file
|
||||||
break
|
break
|
||||||
if not model_file:
|
if not model_file:
|
||||||
raise ValueError('Unknown LoRA format encountered')
|
raise InvalidModelException('Unknown LoRA format encountered')
|
||||||
return LoRACheckpointProbe(model_file,None).get_base_type()
|
return LoRACheckpointProbe(model_file,None).get_base_type()
|
||||||
|
|
||||||
############## register probe classes ######
|
############## register probe classes ######
|
||||||
|
@ -68,7 +68,7 @@ class TextualInversionModel(ModelBase):
|
|||||||
return None # diffusers-ti
|
return None # diffusers-ti
|
||||||
|
|
||||||
if os.path.isfile(path):
|
if os.path.isfile(path):
|
||||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]]):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
raise InvalidModelException(f"Not a valid model: {path}")
|
raise InvalidModelException(f"Not a valid model: {path}")
|
||||||
|
@ -16,6 +16,7 @@ from .base import (
|
|||||||
calc_model_size_by_data,
|
calc_model_size_by_data,
|
||||||
classproperty,
|
classproperty,
|
||||||
InvalidModelException,
|
InvalidModelException,
|
||||||
|
ModelNotFoundException,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from diffusers.utils import is_safetensors_available
|
from diffusers.utils import is_safetensors_available
|
||||||
|
@ -1,4 +0,0 @@
|
|||||||
"""
|
|
||||||
Initialization file for the invokeai.backend.restoration package
|
|
||||||
"""
|
|
||||||
from .base import Restoration
|
|
@ -1,45 +0,0 @@
|
|||||||
import invokeai.backend.util.logging as logger
|
|
||||||
|
|
||||||
class Restoration:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def load_face_restore_models(
|
|
||||||
self, gfpgan_model_path="./models/core/face_restoration/gfpgan/GFPGANv1.4.pth"
|
|
||||||
):
|
|
||||||
# Load GFPGAN
|
|
||||||
gfpgan = self.load_gfpgan(gfpgan_model_path)
|
|
||||||
if gfpgan.gfpgan_model_exists:
|
|
||||||
logger.info("GFPGAN Initialized")
|
|
||||||
else:
|
|
||||||
logger.info("GFPGAN Disabled")
|
|
||||||
gfpgan = None
|
|
||||||
|
|
||||||
# Load CodeFormer
|
|
||||||
codeformer = self.load_codeformer()
|
|
||||||
if codeformer.codeformer_model_exists:
|
|
||||||
logger.info("CodeFormer Initialized")
|
|
||||||
else:
|
|
||||||
logger.info("CodeFormer Disabled")
|
|
||||||
codeformer = None
|
|
||||||
|
|
||||||
return gfpgan, codeformer
|
|
||||||
|
|
||||||
# Face Restore Models
|
|
||||||
def load_gfpgan(self, gfpgan_model_path):
|
|
||||||
from .gfpgan import GFPGAN
|
|
||||||
|
|
||||||
return GFPGAN(gfpgan_model_path)
|
|
||||||
|
|
||||||
def load_codeformer(self):
|
|
||||||
from .codeformer import CodeFormerRestoration
|
|
||||||
|
|
||||||
return CodeFormerRestoration()
|
|
||||||
|
|
||||||
# Upscale Models
|
|
||||||
def load_esrgan(self, esrgan_bg_tile=400):
|
|
||||||
from .realesrgan import ESRGAN
|
|
||||||
|
|
||||||
esrgan = ESRGAN(esrgan_bg_tile)
|
|
||||||
logger.info("ESRGAN Initialized")
|
|
||||||
return esrgan
|
|
@ -1,120 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
|
|
||||||
pretrained_model_url = (
|
|
||||||
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CodeFormerRestoration:
|
|
||||||
def __init__(
|
|
||||||
self, codeformer_dir="./models/core/face_restoration/codeformer", codeformer_model_path="codeformer.pth"
|
|
||||||
) -> None:
|
|
||||||
|
|
||||||
self.globals = InvokeAIAppConfig.get_config()
|
|
||||||
codeformer_dir = self.globals.root_dir / codeformer_dir
|
|
||||||
self.model_path = codeformer_dir / codeformer_model_path
|
|
||||||
self.codeformer_model_exists = self.model_path.exists()
|
|
||||||
|
|
||||||
if not self.codeformer_model_exists:
|
|
||||||
logger.error(f"NOT FOUND: CodeFormer model not found at {self.model_path}")
|
|
||||||
sys.path.append(os.path.abspath(codeformer_dir))
|
|
||||||
|
|
||||||
def process(self, image, strength, device, seed=None, fidelity=0.75):
|
|
||||||
if seed is not None:
|
|
||||||
logger.info(f"CodeFormer - Restoring Faces for image seed:{seed}")
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
|
|
||||||
from basicsr.utils import img2tensor, tensor2img
|
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
|
||||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
|
||||||
from PIL import Image
|
|
||||||
from torchvision.transforms.functional import normalize
|
|
||||||
|
|
||||||
from .codeformer_arch import CodeFormer
|
|
||||||
|
|
||||||
cf_class = CodeFormer
|
|
||||||
|
|
||||||
cf = cf_class(
|
|
||||||
dim_embd=512,
|
|
||||||
codebook_size=1024,
|
|
||||||
n_head=8,
|
|
||||||
n_layers=9,
|
|
||||||
connect_list=["32", "64", "128", "256"],
|
|
||||||
).to(device)
|
|
||||||
|
|
||||||
# note that this file should already be downloaded and cached at
|
|
||||||
# this point
|
|
||||||
checkpoint_path = load_file_from_url(
|
|
||||||
url=pretrained_model_url,
|
|
||||||
model_dir=os.path.abspath(os.path.dirname(self.model_path)),
|
|
||||||
progress=True,
|
|
||||||
)
|
|
||||||
checkpoint = torch.load(checkpoint_path)["params_ema"]
|
|
||||||
cf.load_state_dict(checkpoint)
|
|
||||||
cf.eval()
|
|
||||||
|
|
||||||
image = image.convert("RGB")
|
|
||||||
# Codeformer expects a BGR np array; make array and flip channels
|
|
||||||
bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]
|
|
||||||
|
|
||||||
face_helper = FaceRestoreHelper(
|
|
||||||
upscale_factor=1,
|
|
||||||
use_parse=True,
|
|
||||||
device=device,
|
|
||||||
model_rootpath = self.globals.model_path / 'core/face_restoration/gfpgan/weights'
|
|
||||||
)
|
|
||||||
face_helper.clean_all()
|
|
||||||
face_helper.read_image(bgr_image_array)
|
|
||||||
face_helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5)
|
|
||||||
face_helper.align_warp_face()
|
|
||||||
|
|
||||||
for idx, cropped_face in enumerate(face_helper.cropped_faces):
|
|
||||||
cropped_face_t = img2tensor(
|
|
||||||
cropped_face / 255.0, bgr2rgb=True, float32=True
|
|
||||||
)
|
|
||||||
normalize(
|
|
||||||
cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True
|
|
||||||
)
|
|
||||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with torch.no_grad():
|
|
||||||
output = cf(cropped_face_t, w=fidelity, adain=True)[0]
|
|
||||||
restored_face = tensor2img(
|
|
||||||
output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)
|
|
||||||
)
|
|
||||||
del output
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
except RuntimeError as error:
|
|
||||||
logger.error(f"Failed inference for CodeFormer: {error}.")
|
|
||||||
restored_face = cropped_face
|
|
||||||
|
|
||||||
restored_face = restored_face.astype("uint8")
|
|
||||||
face_helper.add_restored_face(restored_face)
|
|
||||||
|
|
||||||
face_helper.get_inverse_affine(None)
|
|
||||||
|
|
||||||
restored_img = face_helper.paste_faces_to_input_image()
|
|
||||||
|
|
||||||
# Flip the channels back to RGB
|
|
||||||
res = Image.fromarray(restored_img[..., ::-1])
|
|
||||||
|
|
||||||
if strength < 1.0:
|
|
||||||
# Resize the image to the new image if the sizes have changed
|
|
||||||
if restored_img.size != image.size:
|
|
||||||
image = image.resize(res.size)
|
|
||||||
res = Image.blend(image, res, strength)
|
|
||||||
|
|
||||||
cf = None
|
|
||||||
|
|
||||||
return res
|
|
@ -1,325 +0,0 @@
|
|||||||
import math
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from basicsr.utils import get_root_logger
|
|
||||||
from basicsr.utils.registry import ARCH_REGISTRY
|
|
||||||
from torch import Tensor, nn
|
|
||||||
|
|
||||||
from .vqgan_arch import *
|
|
||||||
|
|
||||||
|
|
||||||
def calc_mean_std(feat, eps=1e-5):
|
|
||||||
"""Calculate mean and std for adaptive_instance_normalization.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
feat (Tensor): 4D tensor.
|
|
||||||
eps (float): A small value added to the variance to avoid
|
|
||||||
divide-by-zero. Default: 1e-5.
|
|
||||||
"""
|
|
||||||
size = feat.size()
|
|
||||||
assert len(size) == 4, "The input feature should be 4D tensor."
|
|
||||||
b, c = size[:2]
|
|
||||||
feat_var = feat.view(b, c, -1).var(dim=2) + eps
|
|
||||||
feat_std = feat_var.sqrt().view(b, c, 1, 1)
|
|
||||||
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
|
|
||||||
return feat_mean, feat_std
|
|
||||||
|
|
||||||
|
|
||||||
def adaptive_instance_normalization(content_feat, style_feat):
|
|
||||||
"""Adaptive instance normalization.
|
|
||||||
|
|
||||||
Adjust the reference features to have the similar color and illuminations
|
|
||||||
as those in the degradate features.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content_feat (Tensor): The reference feature.
|
|
||||||
style_feat (Tensor): The degradate features.
|
|
||||||
"""
|
|
||||||
size = content_feat.size()
|
|
||||||
style_mean, style_std = calc_mean_std(style_feat)
|
|
||||||
content_mean, content_std = calc_mean_std(content_feat)
|
|
||||||
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(
|
|
||||||
size
|
|
||||||
)
|
|
||||||
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
|
||||||
|
|
||||||
|
|
||||||
class PositionEmbeddingSine(nn.Module):
|
|
||||||
"""
|
|
||||||
This is a more standard version of the position embedding, very similar to the one
|
|
||||||
used by the Attention is all you need paper, generalized to work on images.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.num_pos_feats = num_pos_feats
|
|
||||||
self.temperature = temperature
|
|
||||||
self.normalize = normalize
|
|
||||||
if scale is not None and normalize is False:
|
|
||||||
raise ValueError("normalize should be True if scale is passed")
|
|
||||||
if scale is None:
|
|
||||||
scale = 2 * math.pi
|
|
||||||
self.scale = scale
|
|
||||||
|
|
||||||
def forward(self, x, mask=None):
|
|
||||||
if mask is None:
|
|
||||||
mask = torch.zeros(
|
|
||||||
(x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
|
|
||||||
)
|
|
||||||
not_mask = ~mask
|
|
||||||
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
|
||||||
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
|
||||||
if self.normalize:
|
|
||||||
eps = 1e-6
|
|
||||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
|
||||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
|
||||||
|
|
||||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
|
||||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
|
||||||
|
|
||||||
pos_x = x_embed[:, :, :, None] / dim_t
|
|
||||||
pos_y = y_embed[:, :, :, None] / dim_t
|
|
||||||
pos_x = torch.stack(
|
|
||||||
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
|
||||||
).flatten(3)
|
|
||||||
pos_y = torch.stack(
|
|
||||||
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
|
||||||
).flatten(3)
|
|
||||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
|
||||||
return pos
|
|
||||||
|
|
||||||
|
|
||||||
def _get_activation_fn(activation):
|
|
||||||
"""Return an activation function given a string"""
|
|
||||||
if activation == "relu":
|
|
||||||
return F.relu
|
|
||||||
if activation == "gelu":
|
|
||||||
return F.gelu
|
|
||||||
if activation == "glu":
|
|
||||||
return F.glu
|
|
||||||
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerSALayer(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
|
|
||||||
# Implementation of Feedforward model - MLP
|
|
||||||
self.linear1 = nn.Linear(embed_dim, dim_mlp)
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
|
||||||
self.linear2 = nn.Linear(dim_mlp, embed_dim)
|
|
||||||
|
|
||||||
self.norm1 = nn.LayerNorm(embed_dim)
|
|
||||||
self.norm2 = nn.LayerNorm(embed_dim)
|
|
||||||
self.dropout1 = nn.Dropout(dropout)
|
|
||||||
self.dropout2 = nn.Dropout(dropout)
|
|
||||||
|
|
||||||
self.activation = _get_activation_fn(activation)
|
|
||||||
|
|
||||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
|
||||||
return tensor if pos is None else tensor + pos
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
tgt,
|
|
||||||
tgt_mask: Optional[Tensor] = None,
|
|
||||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
query_pos: Optional[Tensor] = None,
|
|
||||||
):
|
|
||||||
# self attention
|
|
||||||
tgt2 = self.norm1(tgt)
|
|
||||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
|
||||||
tgt2 = self.self_attn(
|
|
||||||
q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
|
|
||||||
)[0]
|
|
||||||
tgt = tgt + self.dropout1(tgt2)
|
|
||||||
|
|
||||||
# ffn
|
|
||||||
tgt2 = self.norm2(tgt)
|
|
||||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
|
||||||
tgt = tgt + self.dropout2(tgt2)
|
|
||||||
return tgt
|
|
||||||
|
|
||||||
|
|
||||||
class Fuse_sft_block(nn.Module):
|
|
||||||
def __init__(self, in_ch, out_ch):
|
|
||||||
super().__init__()
|
|
||||||
self.encode_enc = ResBlock(2 * in_ch, out_ch)
|
|
||||||
|
|
||||||
self.scale = nn.Sequential(
|
|
||||||
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
|
||||||
nn.LeakyReLU(0.2, True),
|
|
||||||
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.shift = nn.Sequential(
|
|
||||||
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
|
||||||
nn.LeakyReLU(0.2, True),
|
|
||||||
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, enc_feat, dec_feat, w=1):
|
|
||||||
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
|
|
||||||
scale = self.scale(enc_feat)
|
|
||||||
shift = self.shift(enc_feat)
|
|
||||||
residual = w * (dec_feat * scale + shift)
|
|
||||||
out = dec_feat + residual
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
@ARCH_REGISTRY.register()
|
|
||||||
class CodeFormer(VQAutoEncoder):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim_embd=512,
|
|
||||||
n_head=8,
|
|
||||||
n_layers=9,
|
|
||||||
codebook_size=1024,
|
|
||||||
latent_size=256,
|
|
||||||
connect_list=["32", "64", "128", "256"],
|
|
||||||
fix_modules=["quantize", "generator"],
|
|
||||||
):
|
|
||||||
super(CodeFormer, self).__init__(
|
|
||||||
512, 64, [1, 2, 2, 4, 4, 8], "nearest", 2, [16], codebook_size
|
|
||||||
)
|
|
||||||
|
|
||||||
if fix_modules is not None:
|
|
||||||
for module in fix_modules:
|
|
||||||
for param in getattr(self, module).parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
self.connect_list = connect_list
|
|
||||||
self.n_layers = n_layers
|
|
||||||
self.dim_embd = dim_embd
|
|
||||||
self.dim_mlp = dim_embd * 2
|
|
||||||
|
|
||||||
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
|
|
||||||
self.feat_emb = nn.Linear(256, self.dim_embd)
|
|
||||||
|
|
||||||
# transformer
|
|
||||||
self.ft_layers = nn.Sequential(
|
|
||||||
*[
|
|
||||||
TransformerSALayer(
|
|
||||||
embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0
|
|
||||||
)
|
|
||||||
for _ in range(self.n_layers)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# logits_predict head
|
|
||||||
self.idx_pred_layer = nn.Sequential(
|
|
||||||
nn.LayerNorm(dim_embd), nn.Linear(dim_embd, codebook_size, bias=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.channels = {
|
|
||||||
"16": 512,
|
|
||||||
"32": 256,
|
|
||||||
"64": 256,
|
|
||||||
"128": 128,
|
|
||||||
"256": 128,
|
|
||||||
"512": 64,
|
|
||||||
}
|
|
||||||
|
|
||||||
# after second residual block for > 16, before attn layer for ==16
|
|
||||||
self.fuse_encoder_block = {
|
|
||||||
"512": 2,
|
|
||||||
"256": 5,
|
|
||||||
"128": 8,
|
|
||||||
"64": 11,
|
|
||||||
"32": 14,
|
|
||||||
"16": 18,
|
|
||||||
}
|
|
||||||
# after first residual block for > 16, before attn layer for ==16
|
|
||||||
self.fuse_generator_block = {
|
|
||||||
"16": 6,
|
|
||||||
"32": 9,
|
|
||||||
"64": 12,
|
|
||||||
"128": 15,
|
|
||||||
"256": 18,
|
|
||||||
"512": 21,
|
|
||||||
}
|
|
||||||
|
|
||||||
# fuse_convs_dict
|
|
||||||
self.fuse_convs_dict = nn.ModuleDict()
|
|
||||||
for f_size in self.connect_list:
|
|
||||||
in_ch = self.channels[f_size]
|
|
||||||
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
|
||||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
|
||||||
module.weight.data.normal_(mean=0.0, std=0.02)
|
|
||||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
|
||||||
module.bias.data.zero_()
|
|
||||||
elif isinstance(module, nn.LayerNorm):
|
|
||||||
module.bias.data.zero_()
|
|
||||||
module.weight.data.fill_(1.0)
|
|
||||||
|
|
||||||
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
|
|
||||||
# ################### Encoder #####################
|
|
||||||
enc_feat_dict = {}
|
|
||||||
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
|
|
||||||
for i, block in enumerate(self.encoder.blocks):
|
|
||||||
x = block(x)
|
|
||||||
if i in out_list:
|
|
||||||
enc_feat_dict[str(x.shape[-1])] = x.clone()
|
|
||||||
|
|
||||||
lq_feat = x
|
|
||||||
# ################# Transformer ###################
|
|
||||||
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
|
|
||||||
pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1)
|
|
||||||
# BCHW -> BC(HW) -> (HW)BC
|
|
||||||
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
|
|
||||||
query_emb = feat_emb
|
|
||||||
# Transformer encoder
|
|
||||||
for layer in self.ft_layers:
|
|
||||||
query_emb = layer(query_emb, query_pos=pos_emb)
|
|
||||||
|
|
||||||
# output logits
|
|
||||||
logits = self.idx_pred_layer(query_emb) # (hw)bn
|
|
||||||
logits = logits.permute(1, 0, 2) # (hw)bn -> b(hw)n
|
|
||||||
|
|
||||||
if code_only: # for training stage II
|
|
||||||
# logits doesn't need softmax before cross_entropy loss
|
|
||||||
return logits, lq_feat
|
|
||||||
|
|
||||||
# ################# Quantization ###################
|
|
||||||
# if self.training:
|
|
||||||
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
|
|
||||||
# # b(hw)c -> bc(hw) -> bchw
|
|
||||||
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
|
|
||||||
# ------------
|
|
||||||
soft_one_hot = F.softmax(logits, dim=2)
|
|
||||||
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
|
|
||||||
quant_feat = self.quantize.get_codebook_feat(
|
|
||||||
top_idx, shape=[x.shape[0], 16, 16, 256]
|
|
||||||
)
|
|
||||||
# preserve gradients
|
|
||||||
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
|
|
||||||
|
|
||||||
if detach_16:
|
|
||||||
quant_feat = quant_feat.detach() # for training stage III
|
|
||||||
if adain:
|
|
||||||
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
|
|
||||||
|
|
||||||
# ################## Generator ####################
|
|
||||||
x = quant_feat
|
|
||||||
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
|
||||||
|
|
||||||
for i, block in enumerate(self.generator.blocks):
|
|
||||||
x = block(x)
|
|
||||||
if i in fuse_list: # fuse after i-th block
|
|
||||||
f_size = str(x.shape[-1])
|
|
||||||
if w > 0:
|
|
||||||
x = self.fuse_convs_dict[f_size](
|
|
||||||
enc_feat_dict[f_size].detach(), x, w
|
|
||||||
)
|
|
||||||
out = x
|
|
||||||
# logits doesn't need softmax before cross_entropy loss
|
|
||||||
return out, logits, lq_feat
|
|
@ -1,84 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
|
|
||||||
class GFPGAN:
|
|
||||||
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
|
|
||||||
self.globals = InvokeAIAppConfig.get_config()
|
|
||||||
if not os.path.isabs(gfpgan_model_path):
|
|
||||||
gfpgan_model_path = self.globals.root_dir / gfpgan_model_path
|
|
||||||
self.model_path = gfpgan_model_path
|
|
||||||
self.gfpgan_model_exists = os.path.isfile(self.model_path)
|
|
||||||
|
|
||||||
if not self.gfpgan_model_exists:
|
|
||||||
logger.error(f"NOT FOUND: GFPGAN model not found at {self.model_path}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def model_exists(self):
|
|
||||||
return os.path.isfile(self.model_path)
|
|
||||||
|
|
||||||
def process(self, image, strength: float, seed: str = None):
|
|
||||||
if seed is not None:
|
|
||||||
logger.info(f"GFPGAN - Restoring Faces for image seed:{seed}")
|
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
cwd = os.getcwd()
|
|
||||||
os.chdir(self.globals.root_dir / 'models')
|
|
||||||
try:
|
|
||||||
from gfpgan import GFPGANer
|
|
||||||
|
|
||||||
self.gfpgan = GFPGANer(
|
|
||||||
model_path=self.model_path,
|
|
||||||
upscale=1,
|
|
||||||
arch="clean",
|
|
||||||
channel_multiplier=2,
|
|
||||||
bg_upsampler=None,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
logger.error("Error loading GFPGAN:", file=sys.stderr)
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
os.chdir(cwd)
|
|
||||||
|
|
||||||
if self.gfpgan is None:
|
|
||||||
logger.warning("WARNING: GFPGAN not initialized.")
|
|
||||||
logger.warning(
|
|
||||||
f"Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
|
|
||||||
)
|
|
||||||
|
|
||||||
image = image.convert("RGB")
|
|
||||||
|
|
||||||
# GFPGAN expects a BGR np array; make array and flip channels
|
|
||||||
bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]
|
|
||||||
|
|
||||||
_, _, restored_img = self.gfpgan.enhance(
|
|
||||||
bgr_image_array,
|
|
||||||
has_aligned=False,
|
|
||||||
only_center_face=False,
|
|
||||||
paste_back=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Flip the channels back to RGB
|
|
||||||
res = Image.fromarray(restored_img[..., ::-1])
|
|
||||||
|
|
||||||
if strength < 1.0:
|
|
||||||
# Resize the image to the new image if the sizes have changed
|
|
||||||
if restored_img.size != image.size:
|
|
||||||
image = image.resize(res.size)
|
|
||||||
res = Image.blend(image, res, strength)
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
self.gfpgan = None
|
|
||||||
|
|
||||||
return res
|
|
@ -1,118 +0,0 @@
|
|||||||
import math
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
|
|
||||||
class Outcrop(object):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
image,
|
|
||||||
generate, # current generate object
|
|
||||||
):
|
|
||||||
self.image = image
|
|
||||||
self.generate = generate
|
|
||||||
|
|
||||||
def process(
|
|
||||||
self,
|
|
||||||
extents: dict,
|
|
||||||
opt, # current options
|
|
||||||
orig_opt, # ones originally used to generate the image
|
|
||||||
image_callback=None,
|
|
||||||
prefix=None,
|
|
||||||
):
|
|
||||||
# grow and mask the image
|
|
||||||
extended_image = self._extend_all(extents)
|
|
||||||
|
|
||||||
# switch samplers temporarily
|
|
||||||
curr_sampler = self.generate.sampler
|
|
||||||
self.generate.sampler_name = opt.sampler_name
|
|
||||||
self.generate._set_scheduler()
|
|
||||||
|
|
||||||
def wrapped_callback(img, seed, **kwargs):
|
|
||||||
preferred_seed = (
|
|
||||||
orig_opt.seed
|
|
||||||
if orig_opt.seed is not None and orig_opt.seed >= 0
|
|
||||||
else seed
|
|
||||||
)
|
|
||||||
image_callback(img, preferred_seed, use_prefix=prefix, **kwargs)
|
|
||||||
|
|
||||||
result = self.generate.prompt2image(
|
|
||||||
opt.prompt,
|
|
||||||
seed=opt.seed or orig_opt.seed,
|
|
||||||
sampler=self.generate.sampler,
|
|
||||||
steps=opt.steps,
|
|
||||||
cfg_scale=opt.cfg_scale,
|
|
||||||
ddim_eta=self.generate.ddim_eta,
|
|
||||||
width=extended_image.width,
|
|
||||||
height=extended_image.height,
|
|
||||||
init_img=extended_image,
|
|
||||||
strength=0.90,
|
|
||||||
image_callback=wrapped_callback if image_callback else None,
|
|
||||||
seam_size=opt.seam_size or 96,
|
|
||||||
seam_blur=opt.seam_blur or 16,
|
|
||||||
seam_strength=opt.seam_strength or 0.7,
|
|
||||||
seam_steps=20,
|
|
||||||
tile_size=32,
|
|
||||||
color_match=True,
|
|
||||||
force_outpaint=True, # this just stops the warning about erased regions
|
|
||||||
)
|
|
||||||
|
|
||||||
# swap sampler back
|
|
||||||
self.generate.sampler = curr_sampler
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _extend_all(
|
|
||||||
self,
|
|
||||||
extents: dict,
|
|
||||||
) -> Image:
|
|
||||||
"""
|
|
||||||
Extend the image in direction ('top','bottom','left','right') by
|
|
||||||
the indicated value. The image canvas is extended, and the empty
|
|
||||||
rectangular section will be filled with a blurred copy of the
|
|
||||||
adjacent image.
|
|
||||||
"""
|
|
||||||
image = self.image
|
|
||||||
for direction in extents:
|
|
||||||
assert direction in [
|
|
||||||
"top",
|
|
||||||
"left",
|
|
||||||
"bottom",
|
|
||||||
"right",
|
|
||||||
], 'Direction must be one of "top", "left", "bottom", "right"'
|
|
||||||
pixels = extents[direction]
|
|
||||||
# round pixels up to the nearest 64
|
|
||||||
pixels = math.ceil(pixels / 64) * 64
|
|
||||||
logger.info(f"extending image {direction}ward by {pixels} pixels")
|
|
||||||
image = self._rotate(image, direction)
|
|
||||||
image = self._extend(image, pixels)
|
|
||||||
image = self._rotate(image, direction, reverse=True)
|
|
||||||
return image
|
|
||||||
|
|
||||||
def _rotate(self, image: Image, direction: str, reverse=False) -> Image:
|
|
||||||
"""
|
|
||||||
Rotates image so that the area to extend is always at the top top.
|
|
||||||
Simplifies logic later. The reverse argument, if true, will undo the
|
|
||||||
previous transpose.
|
|
||||||
"""
|
|
||||||
transposes = {
|
|
||||||
"right": ["ROTATE_90", "ROTATE_270"],
|
|
||||||
"bottom": ["ROTATE_180", "ROTATE_180"],
|
|
||||||
"left": ["ROTATE_270", "ROTATE_90"],
|
|
||||||
}
|
|
||||||
if direction not in transposes:
|
|
||||||
return image
|
|
||||||
transpose = transposes[direction][1 if reverse else 0]
|
|
||||||
return image.transpose(Image.Transpose.__dict__[transpose])
|
|
||||||
|
|
||||||
def _extend(self, image: Image, pixels: int) -> Image:
|
|
||||||
extended_img = Image.new("RGBA", (image.width, image.height + pixels))
|
|
||||||
|
|
||||||
extended_img.paste((0, 0, 0), [0, 0, image.width, image.height + pixels])
|
|
||||||
extended_img.paste(image, box=(0, pixels))
|
|
||||||
|
|
||||||
# now make the top part transparent to use as a mask
|
|
||||||
alpha = extended_img.getchannel("A")
|
|
||||||
alpha.paste(0, (0, 0, extended_img.width, pixels))
|
|
||||||
extended_img.putalpha(alpha)
|
|
||||||
|
|
||||||
return extended_img
|
|
@ -1,102 +0,0 @@
|
|||||||
import math
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
from PIL import Image, ImageFilter
|
|
||||||
|
|
||||||
|
|
||||||
class Outpaint(object):
|
|
||||||
def __init__(self, image, generate):
|
|
||||||
self.image = image
|
|
||||||
self.generate = generate
|
|
||||||
|
|
||||||
def process(self, opt, old_opt, image_callback=None, prefix=None):
|
|
||||||
image = self._create_outpaint_image(self.image, opt.out_direction)
|
|
||||||
|
|
||||||
seed = old_opt.seed
|
|
||||||
prompt = old_opt.prompt
|
|
||||||
|
|
||||||
def wrapped_callback(img, seed, **kwargs):
|
|
||||||
image_callback(img, seed, use_prefix=prefix, **kwargs)
|
|
||||||
|
|
||||||
return self.generate.prompt2image(
|
|
||||||
prompt,
|
|
||||||
seed=seed,
|
|
||||||
sampler=self.generate.sampler,
|
|
||||||
steps=opt.steps,
|
|
||||||
cfg_scale=opt.cfg_scale,
|
|
||||||
ddim_eta=self.generate.ddim_eta,
|
|
||||||
width=opt.width,
|
|
||||||
height=opt.height,
|
|
||||||
init_img=image,
|
|
||||||
strength=0.83,
|
|
||||||
image_callback=wrapped_callback,
|
|
||||||
prefix=prefix,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_outpaint_image(self, image, direction_args):
|
|
||||||
assert len(direction_args) in [
|
|
||||||
1,
|
|
||||||
2,
|
|
||||||
], "Direction (-D) must have exactly one or two arguments."
|
|
||||||
|
|
||||||
if len(direction_args) == 1:
|
|
||||||
direction = direction_args[0]
|
|
||||||
pixels = None
|
|
||||||
elif len(direction_args) == 2:
|
|
||||||
direction = direction_args[0]
|
|
||||||
pixels = int(direction_args[1])
|
|
||||||
|
|
||||||
assert direction in [
|
|
||||||
"top",
|
|
||||||
"left",
|
|
||||||
"bottom",
|
|
||||||
"right",
|
|
||||||
], 'Direction (-D) must be one of "top", "left", "bottom", "right"'
|
|
||||||
|
|
||||||
image = image.convert("RGBA")
|
|
||||||
# we always extend top, but rotate to extend along the requested side
|
|
||||||
if direction == "left":
|
|
||||||
image = image.transpose(Image.Transpose.ROTATE_270)
|
|
||||||
elif direction == "bottom":
|
|
||||||
image = image.transpose(Image.Transpose.ROTATE_180)
|
|
||||||
elif direction == "right":
|
|
||||||
image = image.transpose(Image.Transpose.ROTATE_90)
|
|
||||||
|
|
||||||
pixels = image.height // 2 if pixels is None else int(pixels)
|
|
||||||
assert (
|
|
||||||
0 < pixels < image.height
|
|
||||||
), "Direction (-D) pixels length must be in the range 0 - image.size"
|
|
||||||
|
|
||||||
# the top part of the image is taken from the source image mirrored
|
|
||||||
# coordinates (0,0) are the upper left corner of an image
|
|
||||||
top = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).convert("RGBA")
|
|
||||||
top = top.crop((0, top.height - pixels, top.width, top.height))
|
|
||||||
|
|
||||||
# setting all alpha of the top part to 0
|
|
||||||
alpha = top.getchannel("A")
|
|
||||||
alpha.paste(0, (0, 0, top.width, top.height))
|
|
||||||
top.putalpha(alpha)
|
|
||||||
|
|
||||||
# taking the bottom from the original image
|
|
||||||
bottom = image.crop((0, 0, image.width, image.height - pixels))
|
|
||||||
|
|
||||||
new_img = image.copy()
|
|
||||||
new_img.paste(top, (0, 0))
|
|
||||||
new_img.paste(bottom, (0, pixels))
|
|
||||||
|
|
||||||
# create a 10% dither in the middle
|
|
||||||
dither = min(image.height // 10, pixels)
|
|
||||||
for x in range(0, image.width, 2):
|
|
||||||
for y in range(pixels - dither, pixels + dither):
|
|
||||||
(r, g, b, a) = new_img.getpixel((x, y))
|
|
||||||
new_img.putpixel((x, y), (r, g, b, 0))
|
|
||||||
|
|
||||||
# let's rotate back again
|
|
||||||
if direction == "left":
|
|
||||||
new_img = new_img.transpose(Image.Transpose.ROTATE_90)
|
|
||||||
elif direction == "bottom":
|
|
||||||
new_img = new_img.transpose(Image.Transpose.ROTATE_180)
|
|
||||||
elif direction == "right":
|
|
||||||
new_img = new_img.transpose(Image.Transpose.ROTATE_270)
|
|
||||||
|
|
||||||
return new_img
|
|
@ -1,104 +0,0 @@
|
|||||||
import warnings
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
from PIL.Image import Image as ImageType
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
|
||||||
|
|
||||||
class ESRGAN:
|
|
||||||
def __init__(self, bg_tile_size=400) -> None:
|
|
||||||
self.bg_tile_size = bg_tile_size
|
|
||||||
|
|
||||||
def load_esrgan_bg_upsampler(self, denoise_str):
|
|
||||||
if not torch.cuda.is_available(): # CPU or MPS on M1
|
|
||||||
use_half_precision = False
|
|
||||||
else:
|
|
||||||
use_half_precision = True
|
|
||||||
|
|
||||||
from realesrgan import RealESRGANer
|
|
||||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
|
||||||
|
|
||||||
model = SRVGGNetCompact(
|
|
||||||
num_in_ch=3,
|
|
||||||
num_out_ch=3,
|
|
||||||
num_feat=64,
|
|
||||||
num_conv=32,
|
|
||||||
upscale=4,
|
|
||||||
act_type="prelu",
|
|
||||||
)
|
|
||||||
model_path = config.models_path / "core/upscaling/realesrgan/realesr-general-x4v3.pth"
|
|
||||||
wdn_model_path = config.models_path / "core/upscaling/realesrgan/realesr-general-wdn-x4v3.pth"
|
|
||||||
scale = 4
|
|
||||||
|
|
||||||
bg_upsampler = RealESRGANer(
|
|
||||||
scale=scale,
|
|
||||||
model_path=[model_path, wdn_model_path],
|
|
||||||
model=model,
|
|
||||||
tile=self.bg_tile_size,
|
|
||||||
dni_weight=[denoise_str, 1 - denoise_str],
|
|
||||||
tile_pad=10,
|
|
||||||
pre_pad=0,
|
|
||||||
half=use_half_precision,
|
|
||||||
)
|
|
||||||
|
|
||||||
return bg_upsampler
|
|
||||||
|
|
||||||
def process(
|
|
||||||
self,
|
|
||||||
image: ImageType,
|
|
||||||
strength: float,
|
|
||||||
seed: str = None,
|
|
||||||
upsampler_scale: int = 2,
|
|
||||||
denoise_str: float = 0.75,
|
|
||||||
):
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
|
|
||||||
try:
|
|
||||||
upsampler = self.load_esrgan_bg_upsampler(denoise_str)
|
|
||||||
except Exception:
|
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
logger.error("Error loading Real-ESRGAN:")
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
if upsampler_scale == 0:
|
|
||||||
logger.warning("Real-ESRGAN: Invalid scaling option. Image not upscaled.")
|
|
||||||
return image
|
|
||||||
|
|
||||||
if seed is not None:
|
|
||||||
logger.info(
|
|
||||||
f"Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
|
|
||||||
)
|
|
||||||
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
|
|
||||||
image = image.convert("RGB")
|
|
||||||
|
|
||||||
# REALSRGAN expects a BGR np array; make array and flip channels
|
|
||||||
bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]
|
|
||||||
|
|
||||||
output, _ = upsampler.enhance(
|
|
||||||
bgr_image_array,
|
|
||||||
outscale=upsampler_scale,
|
|
||||||
alpha_upsampler="realesrgan",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Flip the channels back to RGB
|
|
||||||
res = Image.fromarray(output[..., ::-1])
|
|
||||||
|
|
||||||
if strength < 1.0:
|
|
||||||
# Resize the image to the new image if the sizes have changed
|
|
||||||
if output.size != image.size:
|
|
||||||
image = image.resize(res.size)
|
|
||||||
res = Image.blend(image, res, strength)
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
upsampler = None
|
|
||||||
|
|
||||||
return res
|
|
@ -1,514 +0,0 @@
|
|||||||
"""
|
|
||||||
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
|
|
||||||
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
|
|
||||||
|
|
||||||
"""
|
|
||||||
import copy
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from basicsr.utils import get_root_logger
|
|
||||||
from basicsr.utils.registry import ARCH_REGISTRY
|
|
||||||
|
|
||||||
|
|
||||||
def normalize(in_channels):
|
|
||||||
return torch.nn.GroupNorm(
|
|
||||||
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
def swish(x):
|
|
||||||
return x * torch.sigmoid(x)
|
|
||||||
|
|
||||||
|
|
||||||
# Define VQVAE classes
|
|
||||||
class VectorQuantizer(nn.Module):
|
|
||||||
def __init__(self, codebook_size, emb_dim, beta):
|
|
||||||
super(VectorQuantizer, self).__init__()
|
|
||||||
self.codebook_size = codebook_size # number of embeddings
|
|
||||||
self.emb_dim = emb_dim # dimension of embedding
|
|
||||||
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
|
|
||||||
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
|
|
||||||
self.embedding.weight.data.uniform_(
|
|
||||||
-1.0 / self.codebook_size, 1.0 / self.codebook_size
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, z):
|
|
||||||
# reshape z -> (batch, height, width, channel) and flatten
|
|
||||||
z = z.permute(0, 2, 3, 1).contiguous()
|
|
||||||
z_flattened = z.view(-1, self.emb_dim)
|
|
||||||
|
|
||||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
|
||||||
d = (
|
|
||||||
(z_flattened**2).sum(dim=1, keepdim=True)
|
|
||||||
+ (self.embedding.weight**2).sum(1)
|
|
||||||
- 2 * torch.matmul(z_flattened, self.embedding.weight.t())
|
|
||||||
)
|
|
||||||
|
|
||||||
mean_distance = torch.mean(d)
|
|
||||||
# find closest encodings
|
|
||||||
# min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
|
|
||||||
min_encoding_scores, min_encoding_indices = torch.topk(
|
|
||||||
d, 1, dim=1, largest=False
|
|
||||||
)
|
|
||||||
# [0-1], higher score, higher confidence
|
|
||||||
min_encoding_scores = torch.exp(-min_encoding_scores / 10)
|
|
||||||
|
|
||||||
min_encodings = torch.zeros(
|
|
||||||
min_encoding_indices.shape[0], self.codebook_size
|
|
||||||
).to(z)
|
|
||||||
min_encodings.scatter_(1, min_encoding_indices, 1)
|
|
||||||
|
|
||||||
# get quantized latent vectors
|
|
||||||
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
|
|
||||||
# compute loss for embedding
|
|
||||||
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
|
|
||||||
(z_q - z.detach()) ** 2
|
|
||||||
)
|
|
||||||
# preserve gradients
|
|
||||||
z_q = z + (z_q - z).detach()
|
|
||||||
|
|
||||||
# perplexity
|
|
||||||
e_mean = torch.mean(min_encodings, dim=0)
|
|
||||||
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
|
|
||||||
# reshape back to match original input shape
|
|
||||||
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
|
||||||
|
|
||||||
return (
|
|
||||||
z_q,
|
|
||||||
loss,
|
|
||||||
{
|
|
||||||
"perplexity": perplexity,
|
|
||||||
"min_encodings": min_encodings,
|
|
||||||
"min_encoding_indices": min_encoding_indices,
|
|
||||||
"min_encoding_scores": min_encoding_scores,
|
|
||||||
"mean_distance": mean_distance,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_codebook_feat(self, indices, shape):
|
|
||||||
# input indices: batch*token_num -> (batch*token_num)*1
|
|
||||||
# shape: batch, height, width, channel
|
|
||||||
indices = indices.view(-1, 1)
|
|
||||||
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
|
|
||||||
min_encodings.scatter_(1, indices, 1)
|
|
||||||
# get quantized latent vectors
|
|
||||||
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
|
|
||||||
|
|
||||||
if shape is not None: # reshape back to match original input shape
|
|
||||||
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
|
|
||||||
|
|
||||||
return z_q
|
|
||||||
|
|
||||||
|
|
||||||
class GumbelQuantizer(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
codebook_size,
|
|
||||||
emb_dim,
|
|
||||||
num_hiddens,
|
|
||||||
straight_through=False,
|
|
||||||
kl_weight=5e-4,
|
|
||||||
temp_init=1.0,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.codebook_size = codebook_size # number of embeddings
|
|
||||||
self.emb_dim = emb_dim # dimension of embedding
|
|
||||||
self.straight_through = straight_through
|
|
||||||
self.temperature = temp_init
|
|
||||||
self.kl_weight = kl_weight
|
|
||||||
self.proj = nn.Conv2d(
|
|
||||||
num_hiddens, codebook_size, 1
|
|
||||||
) # projects last encoder layer to quantized logits
|
|
||||||
self.embed = nn.Embedding(codebook_size, emb_dim)
|
|
||||||
|
|
||||||
def forward(self, z):
|
|
||||||
hard = self.straight_through if self.training else True
|
|
||||||
|
|
||||||
logits = self.proj(z)
|
|
||||||
|
|
||||||
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
|
|
||||||
|
|
||||||
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
|
|
||||||
|
|
||||||
# + kl divergence to the prior loss
|
|
||||||
qy = F.softmax(logits, dim=1)
|
|
||||||
diff = (
|
|
||||||
self.kl_weight
|
|
||||||
* torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
|
|
||||||
)
|
|
||||||
min_encoding_indices = soft_one_hot.argmax(dim=1)
|
|
||||||
|
|
||||||
return z_q, diff, {"min_encoding_indices": min_encoding_indices}
|
|
||||||
|
|
||||||
|
|
||||||
class Downsample(nn.Module):
|
|
||||||
def __init__(self, in_channels):
|
|
||||||
super().__init__()
|
|
||||||
self.conv = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
pad = (0, 1, 0, 1)
|
|
||||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
|
||||||
x = self.conv(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Upsample(nn.Module):
|
|
||||||
def __init__(self, in_channels):
|
|
||||||
super().__init__()
|
|
||||||
self.conv = nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
|
||||||
x = self.conv(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class ResBlock(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels=None):
|
|
||||||
super(ResBlock, self).__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = in_channels if out_channels is None else out_channels
|
|
||||||
self.norm1 = normalize(in_channels)
|
|
||||||
self.conv1 = nn.Conv2d(
|
|
||||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
self.norm2 = normalize(out_channels)
|
|
||||||
self.conv2 = nn.Conv2d(
|
|
||||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
if self.in_channels != self.out_channels:
|
|
||||||
self.conv_out = nn.Conv2d(
|
|
||||||
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x_in):
|
|
||||||
x = x_in
|
|
||||||
x = self.norm1(x)
|
|
||||||
x = swish(x)
|
|
||||||
x = self.conv1(x)
|
|
||||||
x = self.norm2(x)
|
|
||||||
x = swish(x)
|
|
||||||
x = self.conv2(x)
|
|
||||||
if self.in_channels != self.out_channels:
|
|
||||||
x_in = self.conv_out(x_in)
|
|
||||||
|
|
||||||
return x + x_in
|
|
||||||
|
|
||||||
|
|
||||||
class AttnBlock(nn.Module):
|
|
||||||
def __init__(self, in_channels):
|
|
||||||
super().__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
|
|
||||||
self.norm = normalize(in_channels)
|
|
||||||
self.q = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
|
||||||
)
|
|
||||||
self.k = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
|
||||||
)
|
|
||||||
self.v = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
|
||||||
)
|
|
||||||
self.proj_out = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
h_ = x
|
|
||||||
h_ = self.norm(h_)
|
|
||||||
q = self.q(h_)
|
|
||||||
k = self.k(h_)
|
|
||||||
v = self.v(h_)
|
|
||||||
|
|
||||||
# compute attention
|
|
||||||
b, c, h, w = q.shape
|
|
||||||
q = q.reshape(b, c, h * w)
|
|
||||||
q = q.permute(0, 2, 1)
|
|
||||||
k = k.reshape(b, c, h * w)
|
|
||||||
w_ = torch.bmm(q, k)
|
|
||||||
w_ = w_ * (int(c) ** (-0.5))
|
|
||||||
w_ = F.softmax(w_, dim=2)
|
|
||||||
|
|
||||||
# attend to values
|
|
||||||
v = v.reshape(b, c, h * w)
|
|
||||||
w_ = w_.permute(0, 2, 1)
|
|
||||||
h_ = torch.bmm(v, w_)
|
|
||||||
h_ = h_.reshape(b, c, h, w)
|
|
||||||
|
|
||||||
h_ = self.proj_out(h_)
|
|
||||||
|
|
||||||
return x + h_
|
|
||||||
|
|
||||||
|
|
||||||
class Encoder(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
nf,
|
|
||||||
emb_dim,
|
|
||||||
ch_mult,
|
|
||||||
num_res_blocks,
|
|
||||||
resolution,
|
|
||||||
attn_resolutions,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.nf = nf
|
|
||||||
self.num_resolutions = len(ch_mult)
|
|
||||||
self.num_res_blocks = num_res_blocks
|
|
||||||
self.resolution = resolution
|
|
||||||
self.attn_resolutions = attn_resolutions
|
|
||||||
|
|
||||||
curr_res = self.resolution
|
|
||||||
in_ch_mult = (1,) + tuple(ch_mult)
|
|
||||||
|
|
||||||
blocks = []
|
|
||||||
# initial convultion
|
|
||||||
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
|
|
||||||
|
|
||||||
# residual and downsampling blocks, with attention on smaller res (16x16)
|
|
||||||
for i in range(self.num_resolutions):
|
|
||||||
block_in_ch = nf * in_ch_mult[i]
|
|
||||||
block_out_ch = nf * ch_mult[i]
|
|
||||||
for _ in range(self.num_res_blocks):
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
|
||||||
block_in_ch = block_out_ch
|
|
||||||
if curr_res in attn_resolutions:
|
|
||||||
blocks.append(AttnBlock(block_in_ch))
|
|
||||||
|
|
||||||
if i != self.num_resolutions - 1:
|
|
||||||
blocks.append(Downsample(block_in_ch))
|
|
||||||
curr_res = curr_res // 2
|
|
||||||
|
|
||||||
# non-local attention block
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
|
||||||
blocks.append(AttnBlock(block_in_ch))
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
|
||||||
|
|
||||||
# normalise and convert to latent size
|
|
||||||
blocks.append(normalize(block_in_ch))
|
|
||||||
blocks.append(
|
|
||||||
nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)
|
|
||||||
)
|
|
||||||
self.blocks = nn.ModuleList(blocks)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
for block in self.blocks:
|
|
||||||
x = block(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Generator(nn.Module):
|
|
||||||
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
|
|
||||||
super().__init__()
|
|
||||||
self.nf = nf
|
|
||||||
self.ch_mult = ch_mult
|
|
||||||
self.num_resolutions = len(self.ch_mult)
|
|
||||||
self.num_res_blocks = res_blocks
|
|
||||||
self.resolution = img_size
|
|
||||||
self.attn_resolutions = attn_resolutions
|
|
||||||
self.in_channels = emb_dim
|
|
||||||
self.out_channels = 3
|
|
||||||
block_in_ch = self.nf * self.ch_mult[-1]
|
|
||||||
curr_res = self.resolution // 2 ** (self.num_resolutions - 1)
|
|
||||||
|
|
||||||
blocks = []
|
|
||||||
# initial conv
|
|
||||||
blocks.append(
|
|
||||||
nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)
|
|
||||||
)
|
|
||||||
|
|
||||||
# non-local attention block
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
|
||||||
blocks.append(AttnBlock(block_in_ch))
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
|
||||||
|
|
||||||
for i in reversed(range(self.num_resolutions)):
|
|
||||||
block_out_ch = self.nf * self.ch_mult[i]
|
|
||||||
|
|
||||||
for _ in range(self.num_res_blocks):
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
|
||||||
block_in_ch = block_out_ch
|
|
||||||
|
|
||||||
if curr_res in self.attn_resolutions:
|
|
||||||
blocks.append(AttnBlock(block_in_ch))
|
|
||||||
|
|
||||||
if i != 0:
|
|
||||||
blocks.append(Upsample(block_in_ch))
|
|
||||||
curr_res = curr_res * 2
|
|
||||||
|
|
||||||
blocks.append(normalize(block_in_ch))
|
|
||||||
blocks.append(
|
|
||||||
nn.Conv2d(
|
|
||||||
block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.blocks = nn.ModuleList(blocks)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
for block in self.blocks:
|
|
||||||
x = block(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
@ARCH_REGISTRY.register()
|
|
||||||
class VQAutoEncoder(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
img_size,
|
|
||||||
nf,
|
|
||||||
ch_mult,
|
|
||||||
quantizer="nearest",
|
|
||||||
res_blocks=2,
|
|
||||||
attn_resolutions=[16],
|
|
||||||
codebook_size=1024,
|
|
||||||
emb_dim=256,
|
|
||||||
beta=0.25,
|
|
||||||
gumbel_straight_through=False,
|
|
||||||
gumbel_kl_weight=1e-8,
|
|
||||||
model_path=None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
logger = get_root_logger()
|
|
||||||
self.in_channels = 3
|
|
||||||
self.nf = nf
|
|
||||||
self.n_blocks = res_blocks
|
|
||||||
self.codebook_size = codebook_size
|
|
||||||
self.embed_dim = emb_dim
|
|
||||||
self.ch_mult = ch_mult
|
|
||||||
self.resolution = img_size
|
|
||||||
self.attn_resolutions = attn_resolutions
|
|
||||||
self.quantizer_type = quantizer
|
|
||||||
self.encoder = Encoder(
|
|
||||||
self.in_channels,
|
|
||||||
self.nf,
|
|
||||||
self.embed_dim,
|
|
||||||
self.ch_mult,
|
|
||||||
self.n_blocks,
|
|
||||||
self.resolution,
|
|
||||||
self.attn_resolutions,
|
|
||||||
)
|
|
||||||
if self.quantizer_type == "nearest":
|
|
||||||
self.beta = beta # 0.25
|
|
||||||
self.quantize = VectorQuantizer(
|
|
||||||
self.codebook_size, self.embed_dim, self.beta
|
|
||||||
)
|
|
||||||
elif self.quantizer_type == "gumbel":
|
|
||||||
self.gumbel_num_hiddens = emb_dim
|
|
||||||
self.straight_through = gumbel_straight_through
|
|
||||||
self.kl_weight = gumbel_kl_weight
|
|
||||||
self.quantize = GumbelQuantizer(
|
|
||||||
self.codebook_size,
|
|
||||||
self.embed_dim,
|
|
||||||
self.gumbel_num_hiddens,
|
|
||||||
self.straight_through,
|
|
||||||
self.kl_weight,
|
|
||||||
)
|
|
||||||
self.generator = Generator(
|
|
||||||
self.nf,
|
|
||||||
self.embed_dim,
|
|
||||||
self.ch_mult,
|
|
||||||
self.n_blocks,
|
|
||||||
self.resolution,
|
|
||||||
self.attn_resolutions,
|
|
||||||
)
|
|
||||||
|
|
||||||
if model_path is not None:
|
|
||||||
chkpt = torch.load(model_path, map_location="cpu")
|
|
||||||
if "params_ema" in chkpt:
|
|
||||||
self.load_state_dict(
|
|
||||||
torch.load(model_path, map_location="cpu")["params_ema"]
|
|
||||||
)
|
|
||||||
logger.info(f"vqgan is loaded from: {model_path} [params_ema]")
|
|
||||||
elif "params" in chkpt:
|
|
||||||
self.load_state_dict(
|
|
||||||
torch.load(model_path, map_location="cpu")["params"]
|
|
||||||
)
|
|
||||||
logger.info(f"vqgan is loaded from: {model_path} [params]")
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Wrong params!")
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.encoder(x)
|
|
||||||
quant, codebook_loss, quant_stats = self.quantize(x)
|
|
||||||
x = self.generator(quant)
|
|
||||||
return x, codebook_loss, quant_stats
|
|
||||||
|
|
||||||
|
|
||||||
# patch based discriminator
|
|
||||||
@ARCH_REGISTRY.register()
|
|
||||||
class VQGANDiscriminator(nn.Module):
|
|
||||||
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
layers = [
|
|
||||||
nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),
|
|
||||||
nn.LeakyReLU(0.2, True),
|
|
||||||
]
|
|
||||||
ndf_mult = 1
|
|
||||||
ndf_mult_prev = 1
|
|
||||||
for n in range(1, n_layers): # gradually increase the number of filters
|
|
||||||
ndf_mult_prev = ndf_mult
|
|
||||||
ndf_mult = min(2**n, 8)
|
|
||||||
layers += [
|
|
||||||
nn.Conv2d(
|
|
||||||
ndf * ndf_mult_prev,
|
|
||||||
ndf * ndf_mult,
|
|
||||||
kernel_size=4,
|
|
||||||
stride=2,
|
|
||||||
padding=1,
|
|
||||||
bias=False,
|
|
||||||
),
|
|
||||||
nn.BatchNorm2d(ndf * ndf_mult),
|
|
||||||
nn.LeakyReLU(0.2, True),
|
|
||||||
]
|
|
||||||
|
|
||||||
ndf_mult_prev = ndf_mult
|
|
||||||
ndf_mult = min(2**n_layers, 8)
|
|
||||||
|
|
||||||
layers += [
|
|
||||||
nn.Conv2d(
|
|
||||||
ndf * ndf_mult_prev,
|
|
||||||
ndf * ndf_mult,
|
|
||||||
kernel_size=4,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
bias=False,
|
|
||||||
),
|
|
||||||
nn.BatchNorm2d(ndf * ndf_mult),
|
|
||||||
nn.LeakyReLU(0.2, True),
|
|
||||||
]
|
|
||||||
|
|
||||||
layers += [
|
|
||||||
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)
|
|
||||||
] # output 1 channel prediction map
|
|
||||||
self.main = nn.Sequential(*layers)
|
|
||||||
|
|
||||||
if model_path is not None:
|
|
||||||
chkpt = torch.load(model_path, map_location="cpu")
|
|
||||||
if "params_d" in chkpt:
|
|
||||||
self.load_state_dict(
|
|
||||||
torch.load(model_path, map_location="cpu")["params_d"]
|
|
||||||
)
|
|
||||||
elif "params" in chkpt:
|
|
||||||
self.load_state_dict(
|
|
||||||
torch.load(model_path, map_location="cpu")["params"]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Wrong params!")
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.main(x)
|
|
@ -221,7 +221,7 @@ class ControlNetData:
|
|||||||
control_mode: str = Field(default="balanced")
|
control_mode: str = Field(default="balanced")
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass
|
||||||
class ConditioningData:
|
class ConditioningData:
|
||||||
unconditioned_embeddings: torch.Tensor
|
unconditioned_embeddings: torch.Tensor
|
||||||
text_embeddings: torch.Tensor
|
text_embeddings: torch.Tensor
|
||||||
@ -422,7 +422,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
noise: torch.Tensor,
|
noise: torch.Tensor,
|
||||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||||
run_id=None,
|
run_id=None,
|
||||||
**kwargs,
|
|
||||||
) -> InvokeAIStableDiffusionPipelineOutput:
|
) -> InvokeAIStableDiffusionPipelineOutput:
|
||||||
r"""
|
r"""
|
||||||
Function invoked when calling the pipeline for generation.
|
Function invoked when calling the pipeline for generation.
|
||||||
@ -443,7 +442,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
noise=noise,
|
noise=noise,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@ -469,7 +467,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
run_id=None,
|
run_id=None,
|
||||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||||
control_data: List[ControlNetData] = None,
|
control_data: List[ControlNetData] = None,
|
||||||
**kwargs,
|
|
||||||
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
||||||
if self.scheduler.config.get("cpu_only", False):
|
if self.scheduler.config.get("cpu_only", False):
|
||||||
scheduler_device = torch.device('cpu')
|
scheduler_device = torch.device('cpu')
|
||||||
@ -487,11 +484,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
timesteps,
|
timesteps,
|
||||||
conditioning_data,
|
conditioning_data,
|
||||||
noise=noise,
|
noise=noise,
|
||||||
additional_guidance=additional_guidance,
|
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
callback=callback,
|
additional_guidance=additional_guidance,
|
||||||
control_data=control_data,
|
control_data=control_data,
|
||||||
**kwargs,
|
|
||||||
|
callback=callback,
|
||||||
)
|
)
|
||||||
return result.latents, result.attention_map_saver
|
return result.latents, result.attention_map_saver
|
||||||
|
|
||||||
@ -505,7 +502,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
run_id: str = None,
|
run_id: str = None,
|
||||||
additional_guidance: List[Callable] = None,
|
additional_guidance: List[Callable] = None,
|
||||||
control_data: List[ControlNetData] = None,
|
control_data: List[ControlNetData] = None,
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
self._adjust_memory_efficient_attention(latents)
|
self._adjust_memory_efficient_attention(latents)
|
||||||
if run_id is None:
|
if run_id is None:
|
||||||
@ -546,7 +542,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
total_step_count=len(timesteps),
|
total_step_count=len(timesteps),
|
||||||
additional_guidance=additional_guidance,
|
additional_guidance=additional_guidance,
|
||||||
control_data=control_data,
|
control_data=control_data,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
latents = step_output.prev_sample
|
latents = step_output.prev_sample
|
||||||
|
|
||||||
@ -588,7 +583,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
total_step_count: int,
|
total_step_count: int,
|
||||||
additional_guidance: List[Callable] = None,
|
additional_guidance: List[Callable] = None,
|
||||||
control_data: List[ControlNetData] = None,
|
control_data: List[ControlNetData] = None,
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||||
timestep = t[0]
|
timestep = t[0]
|
||||||
@ -632,9 +626,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
|
|
||||||
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
|
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
|
||||||
encoder_hidden_states = conditioning_data.text_embeddings
|
encoder_hidden_states = conditioning_data.text_embeddings
|
||||||
|
encoder_attention_mask = None
|
||||||
else:
|
else:
|
||||||
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
|
encoder_hidden_states, encoder_attention_mask = self.invokeai_diffuser._concat_conditionings_for_batch(
|
||||||
conditioning_data.text_embeddings])
|
conditioning_data.unconditioned_embeddings,
|
||||||
|
conditioning_data.text_embeddings,
|
||||||
|
)
|
||||||
if isinstance(control_datum.weight, list):
|
if isinstance(control_datum.weight, list):
|
||||||
# if controlnet has multiple weights, use the weight for the current step
|
# if controlnet has multiple weights, use the weight for the current step
|
||||||
controlnet_weight = control_datum.weight[step_index]
|
controlnet_weight = control_datum.weight[step_index]
|
||||||
@ -649,6 +646,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
controlnet_cond=control_datum.image_tensor,
|
controlnet_cond=control_datum.image_tensor,
|
||||||
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
|
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
|
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
)
|
)
|
||||||
|
@ -237,11 +237,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
)
|
)
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
def _concat_conditionings_for_batch(self, unconditioning, conditioning):
|
||||||
|
|
||||||
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
|
||||||
# fast batched path
|
|
||||||
|
|
||||||
def _pad_conditioning(cond, target_len, encoder_attention_mask):
|
def _pad_conditioning(cond, target_len, encoder_attention_mask):
|
||||||
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
|
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
|
||||||
|
|
||||||
@ -266,16 +262,24 @@ class InvokeAIDiffuserComponent:
|
|||||||
|
|
||||||
return cond, encoder_attention_mask
|
return cond, encoder_attention_mask
|
||||||
|
|
||||||
x_twice = torch.cat([x] * 2)
|
|
||||||
sigma_twice = torch.cat([sigma] * 2)
|
|
||||||
|
|
||||||
encoder_attention_mask = None
|
encoder_attention_mask = None
|
||||||
if unconditioning.shape[1] != conditioning.shape[1]:
|
if unconditioning.shape[1] != conditioning.shape[1]:
|
||||||
max_len = max(unconditioning.shape[1], conditioning.shape[1])
|
max_len = max(unconditioning.shape[1], conditioning.shape[1])
|
||||||
unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
|
unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
|
||||||
conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
|
conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
|
||||||
|
|
||||||
both_conditionings = torch.cat([unconditioning, conditioning])
|
return torch.cat([unconditioning, conditioning]), encoder_attention_mask
|
||||||
|
|
||||||
|
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||||
|
|
||||||
|
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||||
|
# fast batched path
|
||||||
|
x_twice = torch.cat([x] * 2)
|
||||||
|
sigma_twice = torch.cat([sigma] * 2)
|
||||||
|
|
||||||
|
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
|
||||||
|
unconditioning, conditioning
|
||||||
|
)
|
||||||
both_results = self.model_forward_callback(
|
both_results = self.model_forward_callback(
|
||||||
x_twice, sigma_twice, both_conditionings,
|
x_twice, sigma_twice, both_conditionings,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
@ -293,8 +297,32 @@ class InvokeAIDiffuserComponent:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# low-memory sequential path
|
# low-memory sequential path
|
||||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
|
uncond_down_block, cond_down_block = None, None
|
||||||
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
|
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
|
||||||
|
if down_block_additional_residuals is not None:
|
||||||
|
uncond_down_block, cond_down_block = [], []
|
||||||
|
for down_block in down_block_additional_residuals:
|
||||||
|
_uncond_down, _cond_down = down_block.chunk(2)
|
||||||
|
uncond_down_block.append(_uncond_down)
|
||||||
|
cond_down_block.append(_cond_down)
|
||||||
|
|
||||||
|
uncond_mid_block, cond_mid_block = None, None
|
||||||
|
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
|
||||||
|
if mid_block_additional_residual is not None:
|
||||||
|
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
|
||||||
|
|
||||||
|
unconditioned_next_x = self.model_forward_callback(
|
||||||
|
x, sigma, unconditioning,
|
||||||
|
down_block_additional_residuals=uncond_down_block,
|
||||||
|
mid_block_additional_residual=uncond_mid_block,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
conditioned_next_x = self.model_forward_callback(
|
||||||
|
x, sigma, conditioning,
|
||||||
|
down_block_additional_residuals=cond_down_block,
|
||||||
|
mid_block_additional_residual=cond_mid_block,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
# TODO: looks unused
|
# TODO: looks unused
|
||||||
@ -328,6 +356,20 @@ class InvokeAIDiffuserComponent:
|
|||||||
):
|
):
|
||||||
context: Context = self.cross_attention_control_context
|
context: Context = self.cross_attention_control_context
|
||||||
|
|
||||||
|
uncond_down_block, cond_down_block = None, None
|
||||||
|
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
|
||||||
|
if down_block_additional_residuals is not None:
|
||||||
|
uncond_down_block, cond_down_block = [], []
|
||||||
|
for down_block in down_block_additional_residuals:
|
||||||
|
_uncond_down, _cond_down = down_block.chunk(2)
|
||||||
|
uncond_down_block.append(_uncond_down)
|
||||||
|
cond_down_block.append(_cond_down)
|
||||||
|
|
||||||
|
uncond_mid_block, cond_mid_block = None, None
|
||||||
|
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
|
||||||
|
if mid_block_additional_residual is not None:
|
||||||
|
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
|
||||||
|
|
||||||
cross_attn_processor_context = SwapCrossAttnContext(
|
cross_attn_processor_context = SwapCrossAttnContext(
|
||||||
modified_text_embeddings=context.arguments.edited_conditioning,
|
modified_text_embeddings=context.arguments.edited_conditioning,
|
||||||
index_map=context.cross_attention_index_map,
|
index_map=context.cross_attention_index_map,
|
||||||
@ -340,6 +382,8 @@ class InvokeAIDiffuserComponent:
|
|||||||
sigma,
|
sigma,
|
||||||
unconditioning,
|
unconditioning,
|
||||||
{"swap_cross_attn_context": cross_attn_processor_context},
|
{"swap_cross_attn_context": cross_attn_processor_context},
|
||||||
|
down_block_additional_residuals=uncond_down_block,
|
||||||
|
mid_block_additional_residual=uncond_mid_block,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -352,6 +396,8 @@ class InvokeAIDiffuserComponent:
|
|||||||
sigma,
|
sigma,
|
||||||
conditioning,
|
conditioning,
|
||||||
{"swap_cross_attn_context": cross_attn_processor_context},
|
{"swap_cross_attn_context": cross_attn_processor_context},
|
||||||
|
down_block_additional_residuals=cond_down_block,
|
||||||
|
mid_block_additional_residual=cond_mid_block,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
634
invokeai/backend/util/hotfixes.py
Normal file
634
invokeai/backend/util/hotfixes.py
Normal file
@ -0,0 +1,634 @@
|
|||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||||
|
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
|
||||||
|
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
||||||
|
from diffusers.models.modeling_utils import ModelMixin
|
||||||
|
from diffusers.models.unet_2d_blocks import (
|
||||||
|
CrossAttnDownBlock2D,
|
||||||
|
DownBlock2D,
|
||||||
|
UNetMidBlock2DCrossAttn,
|
||||||
|
get_down_block,
|
||||||
|
)
|
||||||
|
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||||
|
|
||||||
|
import diffusers
|
||||||
|
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
|
||||||
|
|
||||||
|
# Modified ControlNetModel with encoder_attention_mask argument added
|
||||||
|
|
||||||
|
class ControlNetModel(ModelMixin, ConfigMixin):
|
||||||
|
"""
|
||||||
|
A ControlNet model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (`int`, defaults to 4):
|
||||||
|
The number of channels in the input sample.
|
||||||
|
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||||
|
Whether to flip the sin to cos in the time embedding.
|
||||||
|
freq_shift (`int`, defaults to 0):
|
||||||
|
The frequency shift to apply to the time embedding.
|
||||||
|
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||||
|
The tuple of downsample blocks to use.
|
||||||
|
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||||
|
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||||
|
The tuple of output channels for each block.
|
||||||
|
layers_per_block (`int`, defaults to 2):
|
||||||
|
The number of layers per block.
|
||||||
|
downsample_padding (`int`, defaults to 1):
|
||||||
|
The padding to use for the downsampling convolution.
|
||||||
|
mid_block_scale_factor (`float`, defaults to 1):
|
||||||
|
The scale factor to use for the mid block.
|
||||||
|
act_fn (`str`, defaults to "silu"):
|
||||||
|
The activation function to use.
|
||||||
|
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||||
|
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
||||||
|
in post-processing.
|
||||||
|
norm_eps (`float`, defaults to 1e-5):
|
||||||
|
The epsilon to use for the normalization.
|
||||||
|
cross_attention_dim (`int`, defaults to 1280):
|
||||||
|
The dimension of the cross attention features.
|
||||||
|
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
||||||
|
The dimension of the attention heads.
|
||||||
|
use_linear_projection (`bool`, defaults to `False`):
|
||||||
|
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||||
|
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
||||||
|
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
||||||
|
num_class_embeds (`int`, *optional*, defaults to 0):
|
||||||
|
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||||
|
class conditioning with `class_embed_type` equal to `None`.
|
||||||
|
upcast_attention (`bool`, defaults to `False`):
|
||||||
|
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
||||||
|
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
||||||
|
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
||||||
|
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
||||||
|
`class_embed_type="projection"`.
|
||||||
|
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
||||||
|
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
||||||
|
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
||||||
|
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
||||||
|
global_pool_conditions (`bool`, defaults to `False`):
|
||||||
|
"""
|
||||||
|
|
||||||
|
_supports_gradient_checkpointing = True
|
||||||
|
|
||||||
|
@register_to_config
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int = 4,
|
||||||
|
conditioning_channels: int = 3,
|
||||||
|
flip_sin_to_cos: bool = True,
|
||||||
|
freq_shift: int = 0,
|
||||||
|
down_block_types: Tuple[str] = (
|
||||||
|
"CrossAttnDownBlock2D",
|
||||||
|
"CrossAttnDownBlock2D",
|
||||||
|
"CrossAttnDownBlock2D",
|
||||||
|
"DownBlock2D",
|
||||||
|
),
|
||||||
|
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||||
|
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||||
|
layers_per_block: int = 2,
|
||||||
|
downsample_padding: int = 1,
|
||||||
|
mid_block_scale_factor: float = 1,
|
||||||
|
act_fn: str = "silu",
|
||||||
|
norm_num_groups: Optional[int] = 32,
|
||||||
|
norm_eps: float = 1e-5,
|
||||||
|
cross_attention_dim: int = 1280,
|
||||||
|
attention_head_dim: Union[int, Tuple[int]] = 8,
|
||||||
|
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
||||||
|
use_linear_projection: bool = False,
|
||||||
|
class_embed_type: Optional[str] = None,
|
||||||
|
num_class_embeds: Optional[int] = None,
|
||||||
|
upcast_attention: bool = False,
|
||||||
|
resnet_time_scale_shift: str = "default",
|
||||||
|
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||||
|
controlnet_conditioning_channel_order: str = "rgb",
|
||||||
|
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
||||||
|
global_pool_conditions: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# If `num_attention_heads` is not defined (which is the case for most models)
|
||||||
|
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
||||||
|
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
||||||
|
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
||||||
|
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
||||||
|
# which is why we correct for the naming here.
|
||||||
|
num_attention_heads = num_attention_heads or attention_head_dim
|
||||||
|
|
||||||
|
# Check inputs
|
||||||
|
if len(block_out_channels) != len(down_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
# input
|
||||||
|
conv_in_kernel = 3
|
||||||
|
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||||
|
self.conv_in = nn.Conv2d(
|
||||||
|
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
||||||
|
)
|
||||||
|
|
||||||
|
# time
|
||||||
|
time_embed_dim = block_out_channels[0] * 4
|
||||||
|
|
||||||
|
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||||
|
timestep_input_dim = block_out_channels[0]
|
||||||
|
|
||||||
|
self.time_embedding = TimestepEmbedding(
|
||||||
|
timestep_input_dim,
|
||||||
|
time_embed_dim,
|
||||||
|
act_fn=act_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
# class embedding
|
||||||
|
if class_embed_type is None and num_class_embeds is not None:
|
||||||
|
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||||
|
elif class_embed_type == "timestep":
|
||||||
|
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||||
|
elif class_embed_type == "identity":
|
||||||
|
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||||
|
elif class_embed_type == "projection":
|
||||||
|
if projection_class_embeddings_input_dim is None:
|
||||||
|
raise ValueError(
|
||||||
|
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
||||||
|
)
|
||||||
|
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
||||||
|
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
||||||
|
# 2. it projects from an arbitrary input dimension.
|
||||||
|
#
|
||||||
|
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
||||||
|
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
||||||
|
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
||||||
|
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||||
|
else:
|
||||||
|
self.class_embedding = None
|
||||||
|
|
||||||
|
# control net conditioning embedding
|
||||||
|
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
|
||||||
|
conditioning_embedding_channels=block_out_channels[0],
|
||||||
|
block_out_channels=conditioning_embedding_out_channels,
|
||||||
|
conditioning_channels=conditioning_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.down_blocks = nn.ModuleList([])
|
||||||
|
self.controlnet_down_blocks = nn.ModuleList([])
|
||||||
|
|
||||||
|
if isinstance(only_cross_attention, bool):
|
||||||
|
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
||||||
|
|
||||||
|
if isinstance(attention_head_dim, int):
|
||||||
|
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||||
|
|
||||||
|
if isinstance(num_attention_heads, int):
|
||||||
|
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
||||||
|
|
||||||
|
# down
|
||||||
|
output_channel = block_out_channels[0]
|
||||||
|
|
||||||
|
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||||
|
controlnet_block = zero_module(controlnet_block)
|
||||||
|
self.controlnet_down_blocks.append(controlnet_block)
|
||||||
|
|
||||||
|
for i, down_block_type in enumerate(down_block_types):
|
||||||
|
input_channel = output_channel
|
||||||
|
output_channel = block_out_channels[i]
|
||||||
|
is_final_block = i == len(block_out_channels) - 1
|
||||||
|
|
||||||
|
down_block = get_down_block(
|
||||||
|
down_block_type,
|
||||||
|
num_layers=layers_per_block,
|
||||||
|
in_channels=input_channel,
|
||||||
|
out_channels=output_channel,
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
add_downsample=not is_final_block,
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
num_attention_heads=num_attention_heads[i],
|
||||||
|
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
||||||
|
downsample_padding=downsample_padding,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
|
only_cross_attention=only_cross_attention[i],
|
||||||
|
upcast_attention=upcast_attention,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
)
|
||||||
|
self.down_blocks.append(down_block)
|
||||||
|
|
||||||
|
for _ in range(layers_per_block):
|
||||||
|
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||||
|
controlnet_block = zero_module(controlnet_block)
|
||||||
|
self.controlnet_down_blocks.append(controlnet_block)
|
||||||
|
|
||||||
|
if not is_final_block:
|
||||||
|
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||||
|
controlnet_block = zero_module(controlnet_block)
|
||||||
|
self.controlnet_down_blocks.append(controlnet_block)
|
||||||
|
|
||||||
|
# mid
|
||||||
|
mid_block_channel = block_out_channels[-1]
|
||||||
|
|
||||||
|
controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
|
||||||
|
controlnet_block = zero_module(controlnet_block)
|
||||||
|
self.controlnet_mid_block = controlnet_block
|
||||||
|
|
||||||
|
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||||
|
in_channels=mid_block_channel,
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
output_scale_factor=mid_block_scale_factor,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
num_attention_heads=num_attention_heads[-1],
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
|
upcast_attention=upcast_attention,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_unet(
|
||||||
|
cls,
|
||||||
|
unet: UNet2DConditionModel,
|
||||||
|
controlnet_conditioning_channel_order: str = "rgb",
|
||||||
|
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
||||||
|
load_weights_from_unet: bool = True,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
unet (`UNet2DConditionModel`):
|
||||||
|
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
|
||||||
|
where applicable.
|
||||||
|
"""
|
||||||
|
controlnet = cls(
|
||||||
|
in_channels=unet.config.in_channels,
|
||||||
|
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
||||||
|
freq_shift=unet.config.freq_shift,
|
||||||
|
down_block_types=unet.config.down_block_types,
|
||||||
|
only_cross_attention=unet.config.only_cross_attention,
|
||||||
|
block_out_channels=unet.config.block_out_channels,
|
||||||
|
layers_per_block=unet.config.layers_per_block,
|
||||||
|
downsample_padding=unet.config.downsample_padding,
|
||||||
|
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
||||||
|
act_fn=unet.config.act_fn,
|
||||||
|
norm_num_groups=unet.config.norm_num_groups,
|
||||||
|
norm_eps=unet.config.norm_eps,
|
||||||
|
cross_attention_dim=unet.config.cross_attention_dim,
|
||||||
|
attention_head_dim=unet.config.attention_head_dim,
|
||||||
|
num_attention_heads=unet.config.num_attention_heads,
|
||||||
|
use_linear_projection=unet.config.use_linear_projection,
|
||||||
|
class_embed_type=unet.config.class_embed_type,
|
||||||
|
num_class_embeds=unet.config.num_class_embeds,
|
||||||
|
upcast_attention=unet.config.upcast_attention,
|
||||||
|
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
||||||
|
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
||||||
|
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
||||||
|
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
if load_weights_from_unet:
|
||||||
|
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
|
||||||
|
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
||||||
|
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
||||||
|
|
||||||
|
if controlnet.class_embedding:
|
||||||
|
controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
|
||||||
|
|
||||||
|
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
|
||||||
|
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
|
||||||
|
|
||||||
|
return controlnet
|
||||||
|
|
||||||
|
@property
|
||||||
|
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||||
|
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||||
|
r"""
|
||||||
|
Returns:
|
||||||
|
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||||
|
indexed by its weight name.
|
||||||
|
"""
|
||||||
|
# set recursively
|
||||||
|
processors = {}
|
||||||
|
|
||||||
|
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||||
|
if hasattr(module, "set_processor"):
|
||||||
|
processors[f"{name}.processor"] = module.processor
|
||||||
|
|
||||||
|
for sub_name, child in module.named_children():
|
||||||
|
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||||
|
|
||||||
|
return processors
|
||||||
|
|
||||||
|
for name, module in self.named_children():
|
||||||
|
fn_recursive_add_processors(name, module, processors)
|
||||||
|
|
||||||
|
return processors
|
||||||
|
|
||||||
|
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||||
|
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||||
|
r"""
|
||||||
|
Sets the attention processor to use to compute attention.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||||
|
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||||
|
for **all** `Attention` layers.
|
||||||
|
|
||||||
|
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||||
|
processor. This is strongly recommended when setting trainable attention processors.
|
||||||
|
|
||||||
|
"""
|
||||||
|
count = len(self.attn_processors.keys())
|
||||||
|
|
||||||
|
if isinstance(processor, dict) and len(processor) != count:
|
||||||
|
raise ValueError(
|
||||||
|
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||||
|
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||||
|
)
|
||||||
|
|
||||||
|
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||||
|
if hasattr(module, "set_processor"):
|
||||||
|
if not isinstance(processor, dict):
|
||||||
|
module.set_processor(processor)
|
||||||
|
else:
|
||||||
|
module.set_processor(processor.pop(f"{name}.processor"))
|
||||||
|
|
||||||
|
for sub_name, child in module.named_children():
|
||||||
|
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||||
|
|
||||||
|
for name, module in self.named_children():
|
||||||
|
fn_recursive_attn_processor(name, module, processor)
|
||||||
|
|
||||||
|
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||||
|
def set_default_attn_processor(self):
|
||||||
|
"""
|
||||||
|
Disables custom attention processors and sets the default attention implementation.
|
||||||
|
"""
|
||||||
|
self.set_attn_processor(AttnProcessor())
|
||||||
|
|
||||||
|
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||||
|
def set_attention_slice(self, slice_size):
|
||||||
|
r"""
|
||||||
|
Enable sliced attention computation.
|
||||||
|
|
||||||
|
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
||||||
|
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
||||||
|
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
||||||
|
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
||||||
|
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
||||||
|
must be a multiple of `slice_size`.
|
||||||
|
"""
|
||||||
|
sliceable_head_dims = []
|
||||||
|
|
||||||
|
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
||||||
|
if hasattr(module, "set_attention_slice"):
|
||||||
|
sliceable_head_dims.append(module.sliceable_head_dim)
|
||||||
|
|
||||||
|
for child in module.children():
|
||||||
|
fn_recursive_retrieve_sliceable_dims(child)
|
||||||
|
|
||||||
|
# retrieve number of attention layers
|
||||||
|
for module in self.children():
|
||||||
|
fn_recursive_retrieve_sliceable_dims(module)
|
||||||
|
|
||||||
|
num_sliceable_layers = len(sliceable_head_dims)
|
||||||
|
|
||||||
|
if slice_size == "auto":
|
||||||
|
# half the attention head size is usually a good trade-off between
|
||||||
|
# speed and memory
|
||||||
|
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
||||||
|
elif slice_size == "max":
|
||||||
|
# make smallest slice possible
|
||||||
|
slice_size = num_sliceable_layers * [1]
|
||||||
|
|
||||||
|
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
||||||
|
|
||||||
|
if len(slice_size) != len(sliceable_head_dims):
|
||||||
|
raise ValueError(
|
||||||
|
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
||||||
|
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(len(slice_size)):
|
||||||
|
size = slice_size[i]
|
||||||
|
dim = sliceable_head_dims[i]
|
||||||
|
if size is not None and size > dim:
|
||||||
|
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
||||||
|
|
||||||
|
# Recursively walk through all the children.
|
||||||
|
# Any children which exposes the set_attention_slice method
|
||||||
|
# gets the message
|
||||||
|
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
||||||
|
if hasattr(module, "set_attention_slice"):
|
||||||
|
module.set_attention_slice(slice_size.pop())
|
||||||
|
|
||||||
|
for child in module.children():
|
||||||
|
fn_recursive_set_attention_slice(child, slice_size)
|
||||||
|
|
||||||
|
reversed_slice_size = list(reversed(slice_size))
|
||||||
|
for module in self.children():
|
||||||
|
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||||
|
|
||||||
|
def _set_gradient_checkpointing(self, module, value=False):
|
||||||
|
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
||||||
|
module.gradient_checkpointing = value
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
sample: torch.FloatTensor,
|
||||||
|
timestep: Union[torch.Tensor, float, int],
|
||||||
|
encoder_hidden_states: torch.Tensor,
|
||||||
|
controlnet_cond: torch.FloatTensor,
|
||||||
|
conditioning_scale: float = 1.0,
|
||||||
|
class_labels: Optional[torch.Tensor] = None,
|
||||||
|
timestep_cond: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
guess_mode: bool = False,
|
||||||
|
return_dict: bool = True,
|
||||||
|
) -> Union[ControlNetOutput, Tuple]:
|
||||||
|
"""
|
||||||
|
The [`ControlNetModel`] forward method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample (`torch.FloatTensor`):
|
||||||
|
The noisy input tensor.
|
||||||
|
timestep (`Union[torch.Tensor, float, int]`):
|
||||||
|
The number of timesteps to denoise an input.
|
||||||
|
encoder_hidden_states (`torch.Tensor`):
|
||||||
|
The encoder hidden states.
|
||||||
|
controlnet_cond (`torch.FloatTensor`):
|
||||||
|
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||||
|
conditioning_scale (`float`, defaults to `1.0`):
|
||||||
|
The scale factor for ControlNet outputs.
|
||||||
|
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||||
|
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
|
||||||
|
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
||||||
|
encoder_attention_mask (`torch.Tensor`):
|
||||||
|
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
||||||
|
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
||||||
|
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
||||||
|
guess_mode (`bool`, defaults to `False`):
|
||||||
|
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
||||||
|
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||||
|
return_dict (`bool`, defaults to `True`):
|
||||||
|
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
|
||||||
|
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
||||||
|
returned where the first element is the sample tensor.
|
||||||
|
"""
|
||||||
|
# check channel order
|
||||||
|
channel_order = self.config.controlnet_conditioning_channel_order
|
||||||
|
|
||||||
|
if channel_order == "rgb":
|
||||||
|
# in rgb order by default
|
||||||
|
...
|
||||||
|
elif channel_order == "bgr":
|
||||||
|
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
|
||||||
|
|
||||||
|
# prepare attention_mask
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||||
|
attention_mask = attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||||
|
if encoder_attention_mask is not None:
|
||||||
|
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
||||||
|
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
# 1. time
|
||||||
|
timesteps = timestep
|
||||||
|
if not torch.is_tensor(timesteps):
|
||||||
|
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||||
|
# This would be a good case for the `match` statement (Python 3.10+)
|
||||||
|
is_mps = sample.device.type == "mps"
|
||||||
|
if isinstance(timestep, float):
|
||||||
|
dtype = torch.float32 if is_mps else torch.float64
|
||||||
|
else:
|
||||||
|
dtype = torch.int32 if is_mps else torch.int64
|
||||||
|
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||||
|
elif len(timesteps.shape) == 0:
|
||||||
|
timesteps = timesteps[None].to(sample.device)
|
||||||
|
|
||||||
|
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||||
|
timesteps = timesteps.expand(sample.shape[0])
|
||||||
|
|
||||||
|
t_emb = self.time_proj(timesteps)
|
||||||
|
|
||||||
|
# timesteps does not contain any weights and will always return f32 tensors
|
||||||
|
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||||
|
# there might be better ways to encapsulate this.
|
||||||
|
t_emb = t_emb.to(dtype=sample.dtype)
|
||||||
|
|
||||||
|
emb = self.time_embedding(t_emb, timestep_cond)
|
||||||
|
|
||||||
|
if self.class_embedding is not None:
|
||||||
|
if class_labels is None:
|
||||||
|
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||||
|
|
||||||
|
if self.config.class_embed_type == "timestep":
|
||||||
|
class_labels = self.time_proj(class_labels)
|
||||||
|
|
||||||
|
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||||
|
emb = emb + class_emb
|
||||||
|
|
||||||
|
# 2. pre-process
|
||||||
|
sample = self.conv_in(sample)
|
||||||
|
|
||||||
|
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
||||||
|
|
||||||
|
sample = sample + controlnet_cond
|
||||||
|
|
||||||
|
# 3. down
|
||||||
|
down_block_res_samples = (sample,)
|
||||||
|
for downsample_block in self.down_blocks:
|
||||||
|
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||||
|
sample, res_samples = downsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||||
|
|
||||||
|
down_block_res_samples += res_samples
|
||||||
|
|
||||||
|
# 4. mid
|
||||||
|
if self.mid_block is not None:
|
||||||
|
sample = self.mid_block(
|
||||||
|
sample,
|
||||||
|
emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. Control net blocks
|
||||||
|
|
||||||
|
controlnet_down_block_res_samples = ()
|
||||||
|
|
||||||
|
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
||||||
|
down_block_res_sample = controlnet_block(down_block_res_sample)
|
||||||
|
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
||||||
|
|
||||||
|
down_block_res_samples = controlnet_down_block_res_samples
|
||||||
|
|
||||||
|
mid_block_res_sample = self.controlnet_mid_block(sample)
|
||||||
|
|
||||||
|
# 6. scaling
|
||||||
|
if guess_mode and not self.config.global_pool_conditions:
|
||||||
|
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
||||||
|
|
||||||
|
scales = scales * conditioning_scale
|
||||||
|
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
||||||
|
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
||||||
|
else:
|
||||||
|
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
||||||
|
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
||||||
|
|
||||||
|
if self.config.global_pool_conditions:
|
||||||
|
down_block_res_samples = [
|
||||||
|
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
||||||
|
]
|
||||||
|
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (down_block_res_samples, mid_block_res_sample)
|
||||||
|
|
||||||
|
return ControlNetOutput(
|
||||||
|
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
||||||
|
)
|
||||||
|
|
||||||
|
diffusers.ControlNetModel = ControlNetModel
|
||||||
|
diffusers.models.controlnet.ControlNetModel = ControlNetModel
|
@ -64,22 +64,29 @@ sd-1/main/waifu-diffusion:
|
|||||||
recommended: False
|
recommended: False
|
||||||
sd-1/controlnet/canny:
|
sd-1/controlnet/canny:
|
||||||
repo_id: lllyasviel/control_v11p_sd15_canny
|
repo_id: lllyasviel/control_v11p_sd15_canny
|
||||||
|
recommended: True
|
||||||
sd-1/controlnet/inpaint:
|
sd-1/controlnet/inpaint:
|
||||||
repo_id: lllyasviel/control_v11p_sd15_inpaint
|
repo_id: lllyasviel/control_v11p_sd15_inpaint
|
||||||
sd-1/controlnet/mlsd:
|
sd-1/controlnet/mlsd:
|
||||||
repo_id: lllyasviel/control_v11p_sd15_mlsd
|
repo_id: lllyasviel/control_v11p_sd15_mlsd
|
||||||
sd-1/controlnet/depth:
|
sd-1/controlnet/depth:
|
||||||
repo_id: lllyasviel/control_v11f1p_sd15_depth
|
repo_id: lllyasviel/control_v11f1p_sd15_depth
|
||||||
|
recommended: True
|
||||||
sd-1/controlnet/normal_bae:
|
sd-1/controlnet/normal_bae:
|
||||||
repo_id: lllyasviel/control_v11p_sd15_normalbae
|
repo_id: lllyasviel/control_v11p_sd15_normalbae
|
||||||
sd-1/controlnet/seg:
|
sd-1/controlnet/seg:
|
||||||
repo_id: lllyasviel/control_v11p_sd15_seg
|
repo_id: lllyasviel/control_v11p_sd15_seg
|
||||||
sd-1/controlnet/lineart:
|
sd-1/controlnet/lineart:
|
||||||
repo_id: lllyasviel/control_v11p_sd15_lineart
|
repo_id: lllyasviel/control_v11p_sd15_lineart
|
||||||
|
recommended: True
|
||||||
sd-1/controlnet/lineart_anime:
|
sd-1/controlnet/lineart_anime:
|
||||||
repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime
|
repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime
|
||||||
|
sd-1/controlnet/openpose:
|
||||||
|
repo_id: lllyasviel/control_v11p_sd15_openpose
|
||||||
|
recommended: True
|
||||||
sd-1/controlnet/scribble:
|
sd-1/controlnet/scribble:
|
||||||
repo_id: lllyasviel/control_v11p_sd15_scribble
|
repo_id: lllyasviel/control_v11p_sd15_scribble
|
||||||
|
recommended: False
|
||||||
sd-1/controlnet/softedge:
|
sd-1/controlnet/softedge:
|
||||||
repo_id: lllyasviel/control_v11p_sd15_softedge
|
repo_id: lllyasviel/control_v11p_sd15_softedge
|
||||||
sd-1/controlnet/shuffle:
|
sd-1/controlnet/shuffle:
|
||||||
@ -90,9 +97,11 @@ sd-1/controlnet/ip2p:
|
|||||||
repo_id: lllyasviel/control_v11e_sd15_ip2p
|
repo_id: lllyasviel/control_v11e_sd15_ip2p
|
||||||
sd-1/embedding/EasyNegative:
|
sd-1/embedding/EasyNegative:
|
||||||
path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
|
path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
|
||||||
|
recommended: True
|
||||||
sd-1/embedding/ahx-beta-453407d:
|
sd-1/embedding/ahx-beta-453407d:
|
||||||
repo_id: sd-concepts-library/ahx-beta-453407d
|
repo_id: sd-concepts-library/ahx-beta-453407d
|
||||||
sd-1/lora/LowRA:
|
sd-1/lora/LowRA:
|
||||||
path: https://civitai.com/api/download/models/63006
|
path: https://civitai.com/api/download/models/63006
|
||||||
|
recommended: True
|
||||||
sd-1/lora/Ink scenery:
|
sd-1/lora/Ink scenery:
|
||||||
path: https://civitai.com/api/download/models/83390
|
path: https://civitai.com/api/download/models/83390
|
||||||
|
@ -256,6 +256,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
widgets = dict()
|
widgets = dict()
|
||||||
model_list = [x for x in self.all_models if self.all_models[x].model_type==model_type and not x in exclude]
|
model_list = [x for x in self.all_models if self.all_models[x].model_type==model_type and not x in exclude]
|
||||||
model_labels = [self.model_labels[x] for x in model_list]
|
model_labels = [self.model_labels[x] for x in model_list]
|
||||||
|
|
||||||
|
show_recommended = len(self.installed_models)==0
|
||||||
if len(model_list) > 0:
|
if len(model_list) > 0:
|
||||||
max_width = max([len(x) for x in model_labels])
|
max_width = max([len(x) for x in model_labels])
|
||||||
columns = window_width // (max_width+8) # 8 characters for "[x] " and padding
|
columns = window_width // (max_width+8) # 8 characters for "[x] " and padding
|
||||||
@ -280,7 +282,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
value=[
|
value=[
|
||||||
model_list.index(x)
|
model_list.index(x)
|
||||||
for x in model_list
|
for x in model_list
|
||||||
if self.all_models[x].installed
|
if (show_recommended and self.all_models[x].recommended) \
|
||||||
|
or self.all_models[x].installed
|
||||||
],
|
],
|
||||||
max_height=len(model_list)//columns + 1,
|
max_height=len(model_list)//columns + 1,
|
||||||
relx=4,
|
relx=4,
|
||||||
@ -672,7 +675,9 @@ def select_and_download_models(opt: Namespace):
|
|||||||
# pass
|
# pass
|
||||||
|
|
||||||
installer = ModelInstall(config, prediction_type_helper=helper)
|
installer = ModelInstall(config, prediction_type_helper=helper)
|
||||||
if opt.add or opt.delete:
|
if opt.list_models:
|
||||||
|
installer.list_models(opt.list_models)
|
||||||
|
elif opt.add or opt.delete:
|
||||||
selections = InstallSelections(
|
selections = InstallSelections(
|
||||||
install_models = opt.add or [],
|
install_models = opt.add or [],
|
||||||
remove_models = opt.delete or []
|
remove_models = opt.delete or []
|
||||||
@ -696,7 +701,7 @@ def select_and_download_models(opt: Namespace):
|
|||||||
|
|
||||||
# the third argument is needed in the Windows 11 environment in
|
# the third argument is needed in the Windows 11 environment in
|
||||||
# order to launch and resize a console window running this program
|
# order to launch and resize a console window running this program
|
||||||
set_min_terminal_size(MIN_COLS, MIN_LINES,'invokeai-model-install')
|
set_min_terminal_size(MIN_COLS, MIN_LINES)
|
||||||
installApp = AddModelApplication(opt)
|
installApp = AddModelApplication(opt)
|
||||||
try:
|
try:
|
||||||
installApp.run()
|
installApp.run()
|
||||||
@ -745,7 +750,7 @@ def main():
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--list-models",
|
"--list-models",
|
||||||
choices=["diffusers","loras","controlnets","tis"],
|
choices=[x.value for x in ModelType],
|
||||||
help="list installed models",
|
help="list installed models",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -773,7 +778,7 @@ def main():
|
|||||||
config.parse_args(invoke_args)
|
config.parse_args(invoke_args)
|
||||||
logger = InvokeAILogger().getLogger(config=config)
|
logger = InvokeAILogger().getLogger(config=config)
|
||||||
|
|
||||||
if not (config.conf_path / 'models.yaml').exists():
|
if not config.model_conf_path.exists():
|
||||||
logger.info(
|
logger.info(
|
||||||
"Your InvokeAI root directory is not set up. Calling invokeai-configure."
|
"Your InvokeAI root directory is not set up. Calling invokeai-configure."
|
||||||
)
|
)
|
||||||
|
@ -17,28 +17,20 @@ from shutil import get_terminal_size
|
|||||||
from curses import BUTTON2_CLICKED,BUTTON3_CLICKED
|
from curses import BUTTON2_CLICKED,BUTTON3_CLICKED
|
||||||
|
|
||||||
# minimum size for UIs
|
# minimum size for UIs
|
||||||
MIN_COLS = 130
|
MIN_COLS = 136
|
||||||
MIN_LINES = 45
|
MIN_LINES = 45
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def set_terminal_size(columns: int, lines: int, launch_command: str=None):
|
def set_terminal_size(columns: int, lines: int):
|
||||||
ts = get_terminal_size()
|
ts = get_terminal_size()
|
||||||
width = max(columns,ts.columns)
|
width = max(columns,ts.columns)
|
||||||
height = max(lines,ts.lines)
|
height = max(lines,ts.lines)
|
||||||
|
|
||||||
OS = platform.uname().system
|
OS = platform.uname().system
|
||||||
if OS == "Windows":
|
if OS == "Windows":
|
||||||
# The new Windows Terminal doesn't resize, so we relaunch in a CMD window.
|
pass
|
||||||
# Would prefer to use execvpe() here, but somehow it is not working properly
|
# not working reliably - ask user to adjust the window
|
||||||
# in the Windows 10 environment.
|
#_set_terminal_size_powershell(width,height)
|
||||||
if 'IA_RELAUNCHED' not in os.environ:
|
|
||||||
args=['conhost']
|
|
||||||
args.extend([launch_command] if launch_command else [sys.argv[0]])
|
|
||||||
args.extend(sys.argv[1:])
|
|
||||||
os.environ['IA_RELAUNCHED'] = 'True'
|
|
||||||
os.execvp('conhost',args)
|
|
||||||
else:
|
|
||||||
_set_terminal_size_powershell(width,height)
|
|
||||||
elif OS in ["Darwin", "Linux"]:
|
elif OS in ["Darwin", "Linux"]:
|
||||||
_set_terminal_size_unix(width,height)
|
_set_terminal_size_unix(width,height)
|
||||||
|
|
||||||
@ -84,20 +76,14 @@ def _set_terminal_size_unix(width: int, height: int):
|
|||||||
sys.stdout.write("\x1b[8;{height};{width}t".format(height=height, width=width))
|
sys.stdout.write("\x1b[8;{height};{width}t".format(height=height, width=width))
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
|
||||||
def set_min_terminal_size(min_cols: int, min_lines: int, launch_command: str=None):
|
def set_min_terminal_size(min_cols: int, min_lines: int):
|
||||||
# make sure there's enough room for the ui
|
# make sure there's enough room for the ui
|
||||||
term_cols, term_lines = get_terminal_size()
|
term_cols, term_lines = get_terminal_size()
|
||||||
if term_cols >= min_cols and term_lines >= min_lines:
|
if term_cols >= min_cols and term_lines >= min_lines:
|
||||||
return
|
return
|
||||||
cols = max(term_cols, min_cols)
|
cols = max(term_cols, min_cols)
|
||||||
lines = max(term_lines, min_lines)
|
lines = max(term_lines, min_lines)
|
||||||
set_terminal_size(cols, lines, launch_command)
|
set_terminal_size(cols, lines)
|
||||||
|
|
||||||
# did it work?
|
|
||||||
term_cols, term_lines = get_terminal_size()
|
|
||||||
if term_cols < cols or term_lines < lines:
|
|
||||||
print(f'This window is too small for optimal display. For best results please enlarge it.')
|
|
||||||
input('After resizing, press any key to continue...')
|
|
||||||
|
|
||||||
class IntSlider(npyscreen.Slider):
|
class IntSlider(npyscreen.Slider):
|
||||||
def translate_value(self):
|
def translate_value(self):
|
||||||
|
169
invokeai/frontend/web/dist/assets/App-c8b96e06.js
vendored
169
invokeai/frontend/web/dist/assets/App-c8b96e06.js
vendored
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
125
invokeai/frontend/web/dist/assets/index-8888b06f.js
vendored
125
invokeai/frontend/web/dist/assets/index-8888b06f.js
vendored
File diff suppressed because one or more lines are too long
2
invokeai/frontend/web/dist/index.html
vendored
2
invokeai/frontend/web/dist/index.html
vendored
@ -12,7 +12,7 @@
|
|||||||
margin: 0;
|
margin: 0;
|
||||||
}
|
}
|
||||||
</style>
|
</style>
|
||||||
<script type="module" crossorigin src="./assets/index-8888b06f.js"></script>
|
<script type="module" crossorigin src="./assets/index-f1a5f9cf.js"></script>
|
||||||
</head>
|
</head>
|
||||||
|
|
||||||
<body dir="ltr">
|
<body dir="ltr">
|
||||||
|
@ -399,6 +399,8 @@
|
|||||||
"deleteModel": "Delete Model",
|
"deleteModel": "Delete Model",
|
||||||
"deleteConfig": "Delete Config",
|
"deleteConfig": "Delete Config",
|
||||||
"deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
|
"deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
|
||||||
|
"modelDeleted": "Model Deleted",
|
||||||
|
"modelDeleteFailed": "Failed to delete model",
|
||||||
"deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
|
"deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
|
||||||
"formMessageDiffusersModelLocation": "Diffusers Model Location",
|
"formMessageDiffusersModelLocation": "Diffusers Model Location",
|
||||||
"formMessageDiffusersModelLocationDesc": "Please enter at least one.",
|
"formMessageDiffusersModelLocationDesc": "Please enter at least one.",
|
||||||
@ -408,11 +410,13 @@
|
|||||||
"convertToDiffusers": "Convert To Diffusers",
|
"convertToDiffusers": "Convert To Diffusers",
|
||||||
"convertToDiffusersHelpText1": "This model will be converted to the 🧨 Diffusers format.",
|
"convertToDiffusersHelpText1": "This model will be converted to the 🧨 Diffusers format.",
|
||||||
"convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.",
|
"convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.",
|
||||||
"convertToDiffusersHelpText3": "Your checkpoint file on the disk will NOT be deleted or modified in anyway. You can add your checkpoint to the Model Manager again if you want to.",
|
"convertToDiffusersHelpText3": "Your checkpoint file on disk WILL be deleted if it is in InvokeAI root folder. If it is in a custom location, then it WILL NOT be deleted.",
|
||||||
"convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.",
|
"convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.",
|
||||||
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.",
|
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.",
|
||||||
"convertToDiffusersHelpText6": "Do you wish to convert this model?",
|
"convertToDiffusersHelpText6": "Do you wish to convert this model?",
|
||||||
"convertToDiffusersSaveLocation": "Save Location",
|
"convertToDiffusersSaveLocation": "Save Location",
|
||||||
|
"noCustomLocationProvided": "No Custom Location Provided",
|
||||||
|
"convertingModelBegin": "Converting Model. Please wait.",
|
||||||
"v1": "v1",
|
"v1": "v1",
|
||||||
"v2_base": "v2 (512px)",
|
"v2_base": "v2 (512px)",
|
||||||
"v2_768": "v2 (768px)",
|
"v2_768": "v2 (768px)",
|
||||||
@ -450,7 +454,8 @@
|
|||||||
"none": "none",
|
"none": "none",
|
||||||
"addDifference": "Add Difference",
|
"addDifference": "Add Difference",
|
||||||
"pickModelType": "Pick Model Type",
|
"pickModelType": "Pick Model Type",
|
||||||
"selectModel": "Select Model"
|
"selectModel": "Select Model",
|
||||||
|
"importModels": "Import Models"
|
||||||
},
|
},
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"general": "General",
|
"general": "General",
|
||||||
@ -572,6 +577,7 @@
|
|||||||
"uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image",
|
"uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image",
|
||||||
"downloadImageStarted": "Image Download Started",
|
"downloadImageStarted": "Image Download Started",
|
||||||
"imageCopied": "Image Copied",
|
"imageCopied": "Image Copied",
|
||||||
|
"problemCopyingImage": "Unable to Copy Image",
|
||||||
"imageLinkCopied": "Image Link Copied",
|
"imageLinkCopied": "Image Link Copied",
|
||||||
"problemCopyingImageLink": "Unable to Copy Image Link",
|
"problemCopyingImageLink": "Unable to Copy Image Link",
|
||||||
"imageNotLoaded": "No Image Loaded",
|
"imageNotLoaded": "No Image Loaded",
|
||||||
@ -688,6 +694,15 @@
|
|||||||
"reloadSchema": "Reload Schema",
|
"reloadSchema": "Reload Schema",
|
||||||
"saveNodes": "Save Nodes",
|
"saveNodes": "Save Nodes",
|
||||||
"loadNodes": "Load Nodes",
|
"loadNodes": "Load Nodes",
|
||||||
"clearNodes": "Clear Nodes"
|
"clearNodes": "Clear Nodes",
|
||||||
|
"zoomInNodes": "Zoom In",
|
||||||
|
"zoomOutNodes": "Zoom Out",
|
||||||
|
"fitViewportNodes": "Fit View",
|
||||||
|
"hideGraphNodes": "Hide Graph Overlay",
|
||||||
|
"showGraphNodes": "Show Graph Overlay",
|
||||||
|
"hideLegendNodes": "Hide Field Type Legend",
|
||||||
|
"showLegendNodes": "Show Field Type Legend",
|
||||||
|
"hideMinimapnodes": "Hide MiniMap",
|
||||||
|
"showMinimapnodes": "Show MiniMap"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||||
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
||||||
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import {
|
import {
|
||||||
setActiveTab,
|
setActiveTab,
|
||||||
toggleGalleryPanel,
|
toggleGalleryPanel,
|
||||||
@ -14,10 +16,11 @@ import React, { memo } from 'react';
|
|||||||
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
|
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
|
||||||
|
|
||||||
const globalHotkeysSelector = createSelector(
|
const globalHotkeysSelector = createSelector(
|
||||||
(state: RootState) => state.hotkeys,
|
[(state: RootState) => state.hotkeys, (state: RootState) => state.ui],
|
||||||
(hotkeys) => {
|
(hotkeys, ui) => {
|
||||||
const { shift } = hotkeys;
|
const { shift } = hotkeys;
|
||||||
return { shift };
|
const { shouldPinParametersPanel, shouldPinGallery } = ui;
|
||||||
|
return { shift, shouldPinGallery, shouldPinParametersPanel };
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
memoizeOptions: {
|
memoizeOptions: {
|
||||||
@ -34,7 +37,10 @@ const globalHotkeysSelector = createSelector(
|
|||||||
*/
|
*/
|
||||||
const GlobalHotkeys: React.FC = () => {
|
const GlobalHotkeys: React.FC = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { shift } = useAppSelector(globalHotkeysSelector);
|
const { shift, shouldPinParametersPanel, shouldPinGallery } = useAppSelector(
|
||||||
|
globalHotkeysSelector
|
||||||
|
);
|
||||||
|
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
'*',
|
'*',
|
||||||
@ -51,18 +57,30 @@ const GlobalHotkeys: React.FC = () => {
|
|||||||
|
|
||||||
useHotkeys('o', () => {
|
useHotkeys('o', () => {
|
||||||
dispatch(toggleParametersPanel());
|
dispatch(toggleParametersPanel());
|
||||||
|
if (activeTabName === 'unifiedCanvas' && shouldPinParametersPanel) {
|
||||||
|
dispatch(requestCanvasRescale());
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
useHotkeys(['shift+o'], () => {
|
useHotkeys(['shift+o'], () => {
|
||||||
dispatch(togglePinParametersPanel());
|
dispatch(togglePinParametersPanel());
|
||||||
|
if (activeTabName === 'unifiedCanvas') {
|
||||||
|
dispatch(requestCanvasRescale());
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
useHotkeys('g', () => {
|
useHotkeys('g', () => {
|
||||||
dispatch(toggleGalleryPanel());
|
dispatch(toggleGalleryPanel());
|
||||||
|
if (activeTabName === 'unifiedCanvas' && shouldPinGallery) {
|
||||||
|
dispatch(requestCanvasRescale());
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
useHotkeys(['shift+g'], () => {
|
useHotkeys(['shift+g'], () => {
|
||||||
dispatch(togglePinGalleryPanel());
|
dispatch(togglePinGalleryPanel());
|
||||||
|
if (activeTabName === 'unifiedCanvas') {
|
||||||
|
dispatch(requestCanvasRescale());
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
useHotkeys('1', () => {
|
useHotkeys('1', () => {
|
||||||
|
@ -88,6 +88,8 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
|
|||||||
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
||||||
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
||||||
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
||||||
|
import { addModelLoadStartedEventListener } from './listeners/socketio/socketModelLoadStarted';
|
||||||
|
import { addModelLoadCompletedEventListener } from './listeners/socketio/socketModelLoadCompleted';
|
||||||
|
|
||||||
export const listenerMiddleware = createListenerMiddleware();
|
export const listenerMiddleware = createListenerMiddleware();
|
||||||
|
|
||||||
@ -177,6 +179,8 @@ addSocketConnectedListener();
|
|||||||
addSocketDisconnectedListener();
|
addSocketDisconnectedListener();
|
||||||
addSocketSubscribedListener();
|
addSocketSubscribedListener();
|
||||||
addSocketUnsubscribedListener();
|
addSocketUnsubscribedListener();
|
||||||
|
addModelLoadStartedEventListener();
|
||||||
|
addModelLoadCompletedEventListener();
|
||||||
|
|
||||||
// Session Created
|
// Session Created
|
||||||
addSessionCreatedPendingListener();
|
addSessionCreatedPendingListener();
|
||||||
|
@ -0,0 +1,28 @@
|
|||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import {
|
||||||
|
appSocketModelLoadCompleted,
|
||||||
|
socketModelLoadCompleted,
|
||||||
|
} from 'services/events/actions';
|
||||||
|
import { startAppListening } from '../..';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'socketio' });
|
||||||
|
|
||||||
|
export const addModelLoadCompletedEventListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: socketModelLoadCompleted,
|
||||||
|
effect: (action, { dispatch, getState }) => {
|
||||||
|
const { model_name, model_type, submodel } = action.payload.data;
|
||||||
|
|
||||||
|
let modelString = `${model_type} model: ${model_name}`;
|
||||||
|
|
||||||
|
if (submodel) {
|
||||||
|
modelString = modelString.concat(`, submodel: ${submodel}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
moduleLog.debug(action.payload, `Model load completed (${modelString})`);
|
||||||
|
|
||||||
|
// pass along the socket event as an application action
|
||||||
|
dispatch(appSocketModelLoadCompleted(action.payload));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,28 @@
|
|||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import {
|
||||||
|
appSocketModelLoadStarted,
|
||||||
|
socketModelLoadStarted,
|
||||||
|
} from 'services/events/actions';
|
||||||
|
import { startAppListening } from '../..';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'socketio' });
|
||||||
|
|
||||||
|
export const addModelLoadStartedEventListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: socketModelLoadStarted,
|
||||||
|
effect: (action, { dispatch, getState }) => {
|
||||||
|
const { model_name, model_type, submodel } = action.payload.data;
|
||||||
|
|
||||||
|
let modelString = `${model_type} model: ${model_name}`;
|
||||||
|
|
||||||
|
if (submodel) {
|
||||||
|
modelString = modelString.concat(`, submodel: ${submodel}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
moduleLog.debug(action.payload, `Model load started (${modelString})`);
|
||||||
|
|
||||||
|
// pass along the socket event as an application action
|
||||||
|
dispatch(appSocketModelLoadStarted(action.payload));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -21,6 +21,7 @@ import generationReducer from 'features/parameters/store/generationSlice';
|
|||||||
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
||||||
import configReducer from 'features/system/store/configSlice';
|
import configReducer from 'features/system/store/configSlice';
|
||||||
import systemReducer from 'features/system/store/systemSlice';
|
import systemReducer from 'features/system/store/systemSlice';
|
||||||
|
import modelmanagerReducer from 'features/ui/components/tabs/ModelManager/store/modelManagerSlice';
|
||||||
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
||||||
import uiReducer from 'features/ui/store/uiSlice';
|
import uiReducer from 'features/ui/store/uiSlice';
|
||||||
|
|
||||||
@ -49,6 +50,7 @@ const allReducers = {
|
|||||||
dynamicPrompts: dynamicPromptsReducer,
|
dynamicPrompts: dynamicPromptsReducer,
|
||||||
imageDeletion: imageDeletionReducer,
|
imageDeletion: imageDeletionReducer,
|
||||||
lora: loraReducer,
|
lora: loraReducer,
|
||||||
|
modelmanager: modelmanagerReducer,
|
||||||
[api.reducerPath]: api.reducer,
|
[api.reducerPath]: api.reducer,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -67,6 +69,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
|
|||||||
'controlNet',
|
'controlNet',
|
||||||
'dynamicPrompts',
|
'dynamicPrompts',
|
||||||
'lora',
|
'lora',
|
||||||
|
'modelmanager',
|
||||||
];
|
];
|
||||||
|
|
||||||
export const store = configureStore({
|
export const store = configureStore({
|
||||||
|
@ -21,6 +21,7 @@ import { ImageDTO } from 'services/api/types';
|
|||||||
import { mode } from 'theme/util/mode';
|
import { mode } from 'theme/util/mode';
|
||||||
import IAIDraggable from './IAIDraggable';
|
import IAIDraggable from './IAIDraggable';
|
||||||
import IAIDroppable from './IAIDroppable';
|
import IAIDroppable from './IAIDroppable';
|
||||||
|
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
|
||||||
|
|
||||||
type IAIDndImageProps = {
|
type IAIDndImageProps = {
|
||||||
imageDTO: ImageDTO | undefined;
|
imageDTO: ImageDTO | undefined;
|
||||||
@ -96,119 +97,124 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<ImageContextMenu imageDTO={imageDTO}>
|
||||||
sx={{
|
{(ref) => (
|
||||||
width: 'full',
|
|
||||||
height: 'full',
|
|
||||||
alignItems: 'center',
|
|
||||||
justifyContent: 'center',
|
|
||||||
position: 'relative',
|
|
||||||
minW: minSize ? minSize : undefined,
|
|
||||||
minH: minSize ? minSize : undefined,
|
|
||||||
userSelect: 'none',
|
|
||||||
cursor: isDragDisabled || !imageDTO ? 'default' : 'pointer',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
{imageDTO && (
|
|
||||||
<Flex
|
<Flex
|
||||||
|
ref={ref}
|
||||||
sx={{
|
sx={{
|
||||||
w: 'full',
|
width: 'full',
|
||||||
h: 'full',
|
height: 'full',
|
||||||
position: fitContainer ? 'absolute' : 'relative',
|
|
||||||
alignItems: 'center',
|
alignItems: 'center',
|
||||||
justifyContent: 'center',
|
justifyContent: 'center',
|
||||||
|
position: 'relative',
|
||||||
|
minW: minSize ? minSize : undefined,
|
||||||
|
minH: minSize ? minSize : undefined,
|
||||||
|
userSelect: 'none',
|
||||||
|
cursor: isDragDisabled || !imageDTO ? 'default' : 'pointer',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Image
|
{imageDTO && (
|
||||||
src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url}
|
<Flex
|
||||||
fallbackStrategy="beforeLoadOrError"
|
|
||||||
// If we fall back to thumbnail, it feels much snappier than the skeleton...
|
|
||||||
fallbackSrc={imageDTO.thumbnail_url}
|
|
||||||
// fallback={<IAILoadingImageFallback image={imageDTO} />}
|
|
||||||
width={imageDTO.width}
|
|
||||||
height={imageDTO.height}
|
|
||||||
onError={onError}
|
|
||||||
draggable={false}
|
|
||||||
sx={{
|
|
||||||
objectFit: 'contain',
|
|
||||||
maxW: 'full',
|
|
||||||
maxH: 'full',
|
|
||||||
borderRadius: 'base',
|
|
||||||
shadow: isSelected ? 'selected.light' : undefined,
|
|
||||||
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
|
|
||||||
...imageSx,
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
{withMetadataOverlay && <ImageMetadataOverlay image={imageDTO} />}
|
|
||||||
</Flex>
|
|
||||||
)}
|
|
||||||
{!imageDTO && !isUploadDisabled && (
|
|
||||||
<>
|
|
||||||
<Flex
|
|
||||||
sx={{
|
|
||||||
minH: minSize,
|
|
||||||
w: 'full',
|
|
||||||
h: 'full',
|
|
||||||
alignItems: 'center',
|
|
||||||
justifyContent: 'center',
|
|
||||||
borderRadius: 'base',
|
|
||||||
transitionProperty: 'common',
|
|
||||||
transitionDuration: '0.1s',
|
|
||||||
color: mode('base.500', 'base.500')(colorMode),
|
|
||||||
...uploadButtonStyles,
|
|
||||||
}}
|
|
||||||
{...getUploadButtonProps()}
|
|
||||||
>
|
|
||||||
<input {...getUploadInputProps()} />
|
|
||||||
<Icon
|
|
||||||
as={FaUpload}
|
|
||||||
sx={{
|
sx={{
|
||||||
boxSize: 16,
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
position: fitContainer ? 'absolute' : 'relative',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Image
|
||||||
|
src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url}
|
||||||
|
fallbackStrategy="beforeLoadOrError"
|
||||||
|
// If we fall back to thumbnail, it feels much snappier than the skeleton...
|
||||||
|
fallbackSrc={imageDTO.thumbnail_url}
|
||||||
|
// fallback={<IAILoadingImageFallback image={imageDTO} />}
|
||||||
|
width={imageDTO.width}
|
||||||
|
height={imageDTO.height}
|
||||||
|
onError={onError}
|
||||||
|
draggable={false}
|
||||||
|
sx={{
|
||||||
|
objectFit: 'contain',
|
||||||
|
maxW: 'full',
|
||||||
|
maxH: 'full',
|
||||||
|
borderRadius: 'base',
|
||||||
|
shadow: isSelected ? 'selected.light' : undefined,
|
||||||
|
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
|
||||||
|
...imageSx,
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
{withMetadataOverlay && <ImageMetadataOverlay image={imageDTO} />}
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
|
{!imageDTO && !isUploadDisabled && (
|
||||||
|
<>
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
minH: minSize,
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
borderRadius: 'base',
|
||||||
|
transitionProperty: 'common',
|
||||||
|
transitionDuration: '0.1s',
|
||||||
|
color: mode('base.500', 'base.500')(colorMode),
|
||||||
|
...uploadButtonStyles,
|
||||||
|
}}
|
||||||
|
{...getUploadButtonProps()}
|
||||||
|
>
|
||||||
|
<input {...getUploadInputProps()} />
|
||||||
|
<Icon
|
||||||
|
as={FaUpload}
|
||||||
|
sx={{
|
||||||
|
boxSize: 16,
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
{!imageDTO && isUploadDisabled && noContentFallback}
|
||||||
|
{!isDropDisabled && (
|
||||||
|
<IAIDroppable
|
||||||
|
data={droppableData}
|
||||||
|
disabled={isDropDisabled}
|
||||||
|
dropLabel={dropLabel}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{imageDTO && !isDragDisabled && (
|
||||||
|
<IAIDraggable
|
||||||
|
data={draggableData}
|
||||||
|
disabled={isDragDisabled || !imageDTO}
|
||||||
|
onClick={onClick}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{onClickReset && withResetIcon && imageDTO && (
|
||||||
|
<IAIIconButton
|
||||||
|
onClick={onClickReset}
|
||||||
|
aria-label={resetTooltip}
|
||||||
|
tooltip={resetTooltip}
|
||||||
|
icon={resetIcon}
|
||||||
|
size="sm"
|
||||||
|
variant="link"
|
||||||
|
sx={{
|
||||||
|
position: 'absolute',
|
||||||
|
top: 1,
|
||||||
|
insetInlineEnd: 1,
|
||||||
|
p: 0,
|
||||||
|
minW: 0,
|
||||||
|
svg: {
|
||||||
|
transitionProperty: 'common',
|
||||||
|
transitionDuration: 'normal',
|
||||||
|
fill: 'base.100',
|
||||||
|
_hover: { fill: 'base.50' },
|
||||||
|
filter: resetIconShadow,
|
||||||
|
},
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
</Flex>
|
)}
|
||||||
</>
|
</Flex>
|
||||||
)}
|
)}
|
||||||
{!imageDTO && isUploadDisabled && noContentFallback}
|
</ImageContextMenu>
|
||||||
{!isDropDisabled && (
|
|
||||||
<IAIDroppable
|
|
||||||
data={droppableData}
|
|
||||||
disabled={isDropDisabled}
|
|
||||||
dropLabel={dropLabel}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{imageDTO && !isDragDisabled && (
|
|
||||||
<IAIDraggable
|
|
||||||
data={draggableData}
|
|
||||||
disabled={isDragDisabled || !imageDTO}
|
|
||||||
onClick={onClick}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{onClickReset && withResetIcon && imageDTO && (
|
|
||||||
<IAIIconButton
|
|
||||||
onClick={onClickReset}
|
|
||||||
aria-label={resetTooltip}
|
|
||||||
tooltip={resetTooltip}
|
|
||||||
icon={resetIcon}
|
|
||||||
size="sm"
|
|
||||||
variant="link"
|
|
||||||
sx={{
|
|
||||||
position: 'absolute',
|
|
||||||
top: 1,
|
|
||||||
insetInlineEnd: 1,
|
|
||||||
p: 0,
|
|
||||||
minW: 0,
|
|
||||||
svg: {
|
|
||||||
transitionProperty: 'common',
|
|
||||||
transitionDuration: 'normal',
|
|
||||||
fill: 'base.100',
|
|
||||||
_hover: { fill: 'base.50' },
|
|
||||||
filter: resetIconShadow,
|
|
||||||
},
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -8,19 +8,34 @@ import {
|
|||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { stopPastePropagation } from 'common/util/stopPastePropagation';
|
import { stopPastePropagation } from 'common/util/stopPastePropagation';
|
||||||
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
||||||
import { ChangeEvent, KeyboardEvent, memo, useCallback } from 'react';
|
import {
|
||||||
|
CSSProperties,
|
||||||
|
ChangeEvent,
|
||||||
|
KeyboardEvent,
|
||||||
|
memo,
|
||||||
|
useCallback,
|
||||||
|
} from 'react';
|
||||||
|
|
||||||
interface IAIInputProps extends InputProps {
|
interface IAIInputProps extends InputProps {
|
||||||
label?: string;
|
label?: string;
|
||||||
|
labelPos?: 'top' | 'side';
|
||||||
value?: string;
|
value?: string;
|
||||||
size?: string;
|
size?: string;
|
||||||
onChange?: (e: ChangeEvent<HTMLInputElement>) => void;
|
onChange?: (e: ChangeEvent<HTMLInputElement>) => void;
|
||||||
formControlProps?: Omit<FormControlProps, 'isInvalid' | 'isDisabled'>;
|
formControlProps?: Omit<FormControlProps, 'isInvalid' | 'isDisabled'>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const labelPosVerticalStyle: CSSProperties = {
|
||||||
|
display: 'flex',
|
||||||
|
flexDirection: 'row',
|
||||||
|
alignItems: 'center',
|
||||||
|
gap: 10,
|
||||||
|
};
|
||||||
|
|
||||||
const IAIInput = (props: IAIInputProps) => {
|
const IAIInput = (props: IAIInputProps) => {
|
||||||
const {
|
const {
|
||||||
label = '',
|
label = '',
|
||||||
|
labelPos = 'top',
|
||||||
isDisabled = false,
|
isDisabled = false,
|
||||||
isInvalid,
|
isInvalid,
|
||||||
formControlProps,
|
formControlProps,
|
||||||
@ -51,6 +66,7 @@ const IAIInput = (props: IAIInputProps) => {
|
|||||||
isInvalid={isInvalid}
|
isInvalid={isInvalid}
|
||||||
isDisabled={isDisabled}
|
isDisabled={isDisabled}
|
||||||
{...formControlProps}
|
{...formControlProps}
|
||||||
|
style={labelPos === 'side' ? labelPosVerticalStyle : undefined}
|
||||||
>
|
>
|
||||||
{label !== '' && <FormLabel>{label}</FormLabel>}
|
{label !== '' && <FormLabel>{label}</FormLabel>}
|
||||||
<Input
|
<Input
|
||||||
|
@ -36,6 +36,7 @@ export default function IAIMantineTextInput(props: IAIMantineTextInputProps) {
|
|||||||
label: {
|
label: {
|
||||||
color: mode(base700, base300)(colorMode),
|
color: mode(base700, base300)(colorMode),
|
||||||
fontWeight: 'normal',
|
fontWeight: 'normal',
|
||||||
|
marginBottom: 4,
|
||||||
},
|
},
|
||||||
})}
|
})}
|
||||||
{...rest}
|
{...rest}
|
||||||
|
@ -9,14 +9,14 @@ export type IAISelectDataType = {
|
|||||||
tooltip?: string;
|
tooltip?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
type IAISelectProps = Omit<SelectProps, 'label'> & {
|
export type IAISelectProps = Omit<SelectProps, 'label'> & {
|
||||||
tooltip?: string;
|
tooltip?: string;
|
||||||
inputRef?: RefObject<HTMLInputElement>;
|
inputRef?: RefObject<HTMLInputElement>;
|
||||||
label?: string;
|
label?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
const IAIMantineSelect = (props: IAISelectProps) => {
|
const IAIMantineSelect = (props: IAISelectProps) => {
|
||||||
const { tooltip, inputRef, label, disabled, ...rest } = props;
|
const { tooltip, inputRef, label, disabled, required, ...rest } = props;
|
||||||
|
|
||||||
const styles = useMantineSelectStyles();
|
const styles = useMantineSelectStyles();
|
||||||
|
|
||||||
@ -25,7 +25,7 @@ const IAIMantineSelect = (props: IAISelectProps) => {
|
|||||||
<Select
|
<Select
|
||||||
label={
|
label={
|
||||||
label ? (
|
label ? (
|
||||||
<FormControl isDisabled={disabled}>
|
<FormControl isRequired={required} isDisabled={disabled}>
|
||||||
<FormLabel>{label}</FormLabel>
|
<FormLabel>{label}</FormLabel>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
) : undefined
|
) : undefined
|
||||||
|
@ -3,4 +3,5 @@ import dateFormat from 'dateformat';
|
|||||||
/**
|
/**
|
||||||
* Get a `now` timestamp with 1s precision, formatted as ISO datetime.
|
* Get a `now` timestamp with 1s precision, formatted as ISO datetime.
|
||||||
*/
|
*/
|
||||||
export const getTimestamp = () => dateFormat(new Date(), 'isoDateTime');
|
export const getTimestamp = () =>
|
||||||
|
dateFormat(new Date(), `yyyy-mm-dd'T'HH:MM:ss:lo`);
|
||||||
|
@ -11,6 +11,7 @@ import {
|
|||||||
setIsMouseOverBoundingBox,
|
setIsMouseOverBoundingBox,
|
||||||
setIsMovingBoundingBox,
|
setIsMovingBoundingBox,
|
||||||
setIsTransformingBoundingBox,
|
setIsTransformingBoundingBox,
|
||||||
|
setShouldSnapToGrid,
|
||||||
} from 'features/canvas/store/canvasSlice';
|
} from 'features/canvas/store/canvasSlice';
|
||||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||||
import Konva from 'konva';
|
import Konva from 'konva';
|
||||||
@ -20,6 +21,7 @@ import { Vector2d } from 'konva/lib/types';
|
|||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
|
|
||||||
import { useCallback, useEffect, useRef, useState } from 'react';
|
import { useCallback, useEffect, useRef, useState } from 'react';
|
||||||
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
import { Group, Rect, Transformer } from 'react-konva';
|
import { Group, Rect, Transformer } from 'react-konva';
|
||||||
|
|
||||||
const boundingBoxPreviewSelector = createSelector(
|
const boundingBoxPreviewSelector = createSelector(
|
||||||
@ -91,6 +93,10 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => {
|
|||||||
|
|
||||||
const scaledStep = 64 * stageScale;
|
const scaledStep = 64 * stageScale;
|
||||||
|
|
||||||
|
useHotkeys('N', () => {
|
||||||
|
dispatch(setShouldSnapToGrid(!shouldSnapToGrid));
|
||||||
|
});
|
||||||
|
|
||||||
const handleOnDragMove = useCallback(
|
const handleOnDragMove = useCallback(
|
||||||
(e: KonvaEventObject<DragEvent>) => {
|
(e: KonvaEventObject<DragEvent>) => {
|
||||||
if (!shouldSnapToGrid) {
|
if (!shouldSnapToGrid) {
|
||||||
|
@ -139,7 +139,7 @@ const IAICanvasToolChooserOptions = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
['shift+BracketLeft'],
|
['Shift+BracketLeft'],
|
||||||
() => {
|
() => {
|
||||||
dispatch(
|
dispatch(
|
||||||
setBrushColor({
|
setBrushColor({
|
||||||
@ -156,7 +156,7 @@ const IAICanvasToolChooserOptions = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
['shift+BracketRight'],
|
['Shift+BracketRight'],
|
||||||
() => {
|
() => {
|
||||||
dispatch(
|
dispatch(
|
||||||
setBrushColor({
|
setBrushColor({
|
||||||
|
@ -48,6 +48,7 @@ import IAICanvasRedoButton from './IAICanvasRedoButton';
|
|||||||
import IAICanvasSettingsButtonPopover from './IAICanvasSettingsButtonPopover';
|
import IAICanvasSettingsButtonPopover from './IAICanvasSettingsButtonPopover';
|
||||||
import IAICanvasToolChooserOptions from './IAICanvasToolChooserOptions';
|
import IAICanvasToolChooserOptions from './IAICanvasToolChooserOptions';
|
||||||
import IAICanvasUndoButton from './IAICanvasUndoButton';
|
import IAICanvasUndoButton from './IAICanvasUndoButton';
|
||||||
|
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
|
||||||
|
|
||||||
export const selector = createSelector(
|
export const selector = createSelector(
|
||||||
[systemSelector, canvasSelector, isStagingSelector],
|
[systemSelector, canvasSelector, isStagingSelector],
|
||||||
@ -79,6 +80,7 @@ const IAICanvasToolbar = () => {
|
|||||||
const canvasBaseLayer = getCanvasBaseLayer();
|
const canvasBaseLayer = getCanvasBaseLayer();
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
const { isClipboardAPIAvailable } = useCopyImageToClipboard();
|
||||||
|
|
||||||
const { openUploader } = useImageUploader();
|
const { openUploader } = useImageUploader();
|
||||||
|
|
||||||
@ -136,10 +138,10 @@ const IAICanvasToolbar = () => {
|
|||||||
handleCopyImageToClipboard();
|
handleCopyImageToClipboard();
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
enabled: () => !isStaging,
|
enabled: () => !isStaging && isClipboardAPIAvailable,
|
||||||
preventDefault: true,
|
preventDefault: true,
|
||||||
},
|
},
|
||||||
[canvasBaseLayer, isProcessing]
|
[canvasBaseLayer, isProcessing, isClipboardAPIAvailable]
|
||||||
);
|
);
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
@ -189,6 +191,9 @@ const IAICanvasToolbar = () => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const handleCopyImageToClipboard = () => {
|
const handleCopyImageToClipboard = () => {
|
||||||
|
if (!isClipboardAPIAvailable) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
dispatch(canvasCopiedToClipboard());
|
dispatch(canvasCopiedToClipboard());
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -256,13 +261,15 @@ const IAICanvasToolbar = () => {
|
|||||||
onClick={handleSaveToGallery}
|
onClick={handleSaveToGallery}
|
||||||
isDisabled={isStaging}
|
isDisabled={isStaging}
|
||||||
/>
|
/>
|
||||||
<IAIIconButton
|
{isClipboardAPIAvailable && (
|
||||||
aria-label={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`}
|
<IAIIconButton
|
||||||
tooltip={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`}
|
aria-label={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`}
|
||||||
icon={<FaCopy />}
|
tooltip={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`}
|
||||||
onClick={handleCopyImageToClipboard}
|
icon={<FaCopy />}
|
||||||
isDisabled={isStaging}
|
onClick={handleCopyImageToClipboard}
|
||||||
/>
|
isDisabled={isStaging}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
aria-label={`${t('unifiedCanvas.downloadAsImage')} (Shift+D)`}
|
aria-label={`${t('unifiedCanvas.downloadAsImage')} (Shift+D)`}
|
||||||
tooltip={`${t('unifiedCanvas.downloadAsImage')} (Shift+D)`}
|
tooltip={`${t('unifiedCanvas.downloadAsImage')} (Shift+D)`}
|
||||||
|
@ -1,7 +1,16 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
|
|
||||||
import { ButtonGroup, Flex, FlexProps, Link } from '@chakra-ui/react';
|
import {
|
||||||
|
ButtonGroup,
|
||||||
|
Flex,
|
||||||
|
FlexProps,
|
||||||
|
Link,
|
||||||
|
Menu,
|
||||||
|
MenuButton,
|
||||||
|
MenuItem,
|
||||||
|
MenuList,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
// import { runESRGAN, runFacetool } from 'app/socketio/actions';
|
// import { runESRGAN, runFacetool } from 'app/socketio/actions';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
@ -20,6 +29,7 @@ import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/U
|
|||||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
|
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import {
|
import {
|
||||||
setActiveTab,
|
setActiveTab,
|
||||||
@ -48,6 +58,8 @@ import {
|
|||||||
} from 'services/api/endpoints/images';
|
} from 'services/api/endpoints/images';
|
||||||
import { useDebounce } from 'use-debounce';
|
import { useDebounce } from 'use-debounce';
|
||||||
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
|
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
|
||||||
|
import { menuListMotionProps } from 'theme/components/menu';
|
||||||
|
import SingleSelectionMenuItems from '../ImageContextMenu/SingleSelectionMenuItems';
|
||||||
|
|
||||||
const currentImageButtonsSelector = createSelector(
|
const currentImageButtonsSelector = createSelector(
|
||||||
[stateSelector, activeTabNameSelector],
|
[stateSelector, activeTabNameSelector],
|
||||||
@ -120,6 +132,9 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
const toaster = useAppToaster();
|
const toaster = useAppToaster();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const { isClipboardAPIAvailable, copyImageToClipboard } =
|
||||||
|
useCopyImageToClipboard();
|
||||||
|
|
||||||
const { recallBothPrompts, recallSeed, recallAllParameters } =
|
const { recallBothPrompts, recallSeed, recallAllParameters } =
|
||||||
useRecallParameters();
|
useRecallParameters();
|
||||||
|
|
||||||
@ -128,7 +143,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
500
|
500
|
||||||
);
|
);
|
||||||
|
|
||||||
const { currentData: image, isFetching } = useGetImageDTOQuery(
|
const { currentData: imageDTO, isFetching } = useGetImageDTOQuery(
|
||||||
lastSelectedImage ?? skipToken
|
lastSelectedImage ?? skipToken
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -142,15 +157,15 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
|
|
||||||
const handleCopyImageLink = useCallback(() => {
|
const handleCopyImageLink = useCallback(() => {
|
||||||
const getImageUrl = () => {
|
const getImageUrl = () => {
|
||||||
if (!image) {
|
if (!imageDTO) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (image.image_url.startsWith('http')) {
|
if (imageDTO.image_url.startsWith('http')) {
|
||||||
return image.image_url;
|
return imageDTO.image_url;
|
||||||
}
|
}
|
||||||
|
|
||||||
return window.location.toString() + image.image_url;
|
return window.location.toString() + imageDTO.image_url;
|
||||||
};
|
};
|
||||||
|
|
||||||
const url = getImageUrl();
|
const url = getImageUrl();
|
||||||
@ -174,7 +189,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
isClosable: true,
|
isClosable: true,
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}, [toaster, t, image]);
|
}, [toaster, t, imageDTO]);
|
||||||
|
|
||||||
const handleClickUseAllParameters = useCallback(() => {
|
const handleClickUseAllParameters = useCallback(() => {
|
||||||
recallAllParameters(metadata);
|
recallAllParameters(metadata);
|
||||||
@ -192,31 +207,31 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
recallSeed(metadata?.seed);
|
recallSeed(metadata?.seed);
|
||||||
}, [metadata?.seed, recallSeed]);
|
}, [metadata?.seed, recallSeed]);
|
||||||
|
|
||||||
useHotkeys('s', handleUseSeed, [image]);
|
useHotkeys('s', handleUseSeed, [imageDTO]);
|
||||||
|
|
||||||
const handleUsePrompt = useCallback(() => {
|
const handleUsePrompt = useCallback(() => {
|
||||||
recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt);
|
recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt);
|
||||||
}, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]);
|
}, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]);
|
||||||
|
|
||||||
useHotkeys('p', handleUsePrompt, [image]);
|
useHotkeys('p', handleUsePrompt, [imageDTO]);
|
||||||
|
|
||||||
const handleSendToImageToImage = useCallback(() => {
|
const handleSendToImageToImage = useCallback(() => {
|
||||||
dispatch(sentImageToImg2Img());
|
dispatch(sentImageToImg2Img());
|
||||||
dispatch(initialImageSelected(image));
|
dispatch(initialImageSelected(imageDTO));
|
||||||
}, [dispatch, image]);
|
}, [dispatch, imageDTO]);
|
||||||
|
|
||||||
useHotkeys('shift+i', handleSendToImageToImage, [image]);
|
useHotkeys('shift+i', handleSendToImageToImage, [imageDTO]);
|
||||||
|
|
||||||
const handleClickUpscale = useCallback(() => {
|
const handleClickUpscale = useCallback(() => {
|
||||||
// selectedImage && dispatch(runESRGAN(selectedImage));
|
// selectedImage && dispatch(runESRGAN(selectedImage));
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const handleDelete = useCallback(() => {
|
const handleDelete = useCallback(() => {
|
||||||
if (!image) {
|
if (!imageDTO) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
dispatch(imageToDeleteSelected(image));
|
dispatch(imageToDeleteSelected(imageDTO));
|
||||||
}, [dispatch, image]);
|
}, [dispatch, imageDTO]);
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
'Shift+U',
|
'Shift+U',
|
||||||
@ -236,7 +251,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
},
|
},
|
||||||
[
|
[
|
||||||
isUpscalingEnabled,
|
isUpscalingEnabled,
|
||||||
image,
|
imageDTO,
|
||||||
isESRGANAvailable,
|
isESRGANAvailable,
|
||||||
shouldDisableToolbarButtons,
|
shouldDisableToolbarButtons,
|
||||||
isConnected,
|
isConnected,
|
||||||
@ -268,7 +283,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
|
|
||||||
[
|
[
|
||||||
isFaceRestoreEnabled,
|
isFaceRestoreEnabled,
|
||||||
image,
|
imageDTO,
|
||||||
isGFPGANAvailable,
|
isGFPGANAvailable,
|
||||||
shouldDisableToolbarButtons,
|
shouldDisableToolbarButtons,
|
||||||
isConnected,
|
isConnected,
|
||||||
@ -283,10 +298,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const handleSendToCanvas = useCallback(() => {
|
const handleSendToCanvas = useCallback(() => {
|
||||||
if (!image) return;
|
if (!imageDTO) return;
|
||||||
dispatch(sentImageToCanvas());
|
dispatch(sentImageToCanvas());
|
||||||
|
|
||||||
dispatch(setInitialCanvasImage(image));
|
dispatch(setInitialCanvasImage(imageDTO));
|
||||||
dispatch(requestCanvasRescale());
|
dispatch(requestCanvasRescale());
|
||||||
|
|
||||||
if (activeTabName !== 'unifiedCanvas') {
|
if (activeTabName !== 'unifiedCanvas') {
|
||||||
@ -299,12 +314,12 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
duration: 2500,
|
duration: 2500,
|
||||||
isClosable: true,
|
isClosable: true,
|
||||||
});
|
});
|
||||||
}, [image, dispatch, activeTabName, toaster, t]);
|
}, [imageDTO, dispatch, activeTabName, toaster, t]);
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
'i',
|
'i',
|
||||||
() => {
|
() => {
|
||||||
if (image) {
|
if (imageDTO) {
|
||||||
handleClickShowImageDetails();
|
handleClickShowImageDetails();
|
||||||
} else {
|
} else {
|
||||||
toaster({
|
toaster({
|
||||||
@ -315,13 +330,20 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[image, shouldShowImageDetails, toaster]
|
[imageDTO, shouldShowImageDetails, toaster]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleClickProgressImagesToggle = useCallback(() => {
|
const handleClickProgressImagesToggle = useCallback(() => {
|
||||||
dispatch(setShouldShowProgressInViewer(!shouldShowProgressInViewer));
|
dispatch(setShouldShowProgressInViewer(!shouldShowProgressInViewer));
|
||||||
}, [dispatch, shouldShowProgressInViewer]);
|
}, [dispatch, shouldShowProgressInViewer]);
|
||||||
|
|
||||||
|
const handleCopyImage = useCallback(() => {
|
||||||
|
if (!imageDTO) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
copyImageToClipboard(imageDTO.image_url);
|
||||||
|
}, [copyImageToClipboard, imageDTO]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Flex
|
<Flex
|
||||||
@ -334,63 +356,18 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
{...props}
|
{...props}
|
||||||
>
|
>
|
||||||
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
|
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
|
||||||
<IAIPopover
|
<Menu>
|
||||||
triggerComponent={
|
<MenuButton
|
||||||
<IAIIconButton
|
as={IAIIconButton}
|
||||||
aria-label={`${t('parameters.sendTo')}...`}
|
aria-label={`${t('parameters.sendTo')}...`}
|
||||||
tooltip={`${t('parameters.sendTo')}...`}
|
tooltip={`${t('parameters.sendTo')}...`}
|
||||||
isDisabled={!image}
|
isDisabled={!imageDTO}
|
||||||
icon={<FaShareAlt />}
|
icon={<FaShareAlt />}
|
||||||
/>
|
/>
|
||||||
}
|
<MenuList motionProps={menuListMotionProps}>
|
||||||
>
|
{imageDTO && <SingleSelectionMenuItems imageDTO={imageDTO} />}
|
||||||
<Flex
|
</MenuList>
|
||||||
sx={{
|
</Menu>
|
||||||
flexDirection: 'column',
|
|
||||||
rowGap: 2,
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<IAIButton
|
|
||||||
size="sm"
|
|
||||||
onClick={handleSendToImageToImage}
|
|
||||||
leftIcon={<FaShare />}
|
|
||||||
id="send-to-img2img"
|
|
||||||
>
|
|
||||||
{t('parameters.sendToImg2Img')}
|
|
||||||
</IAIButton>
|
|
||||||
{isCanvasEnabled && (
|
|
||||||
<IAIButton
|
|
||||||
size="sm"
|
|
||||||
onClick={handleSendToCanvas}
|
|
||||||
leftIcon={<FaShare />}
|
|
||||||
id="send-to-canvas"
|
|
||||||
>
|
|
||||||
{t('parameters.sendToUnifiedCanvas')}
|
|
||||||
</IAIButton>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{/* <IAIButton
|
|
||||||
size="sm"
|
|
||||||
onClick={handleCopyImage}
|
|
||||||
leftIcon={<FaCopy />}
|
|
||||||
>
|
|
||||||
{t('parameters.copyImage')}
|
|
||||||
</IAIButton> */}
|
|
||||||
<IAIButton
|
|
||||||
size="sm"
|
|
||||||
onClick={handleCopyImageLink}
|
|
||||||
leftIcon={<FaCopy />}
|
|
||||||
>
|
|
||||||
{t('parameters.copyImageToLink')}
|
|
||||||
</IAIButton>
|
|
||||||
|
|
||||||
<Link download={true} href={image?.image_url} target="_blank">
|
|
||||||
<IAIButton leftIcon={<FaDownload />} size="sm" w="100%">
|
|
||||||
{t('parameters.downloadImage')}
|
|
||||||
</IAIButton>
|
|
||||||
</Link>
|
|
||||||
</Flex>
|
|
||||||
</IAIPopover>
|
|
||||||
</ButtonGroup>
|
</ButtonGroup>
|
||||||
|
|
||||||
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
|
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
|
||||||
@ -443,7 +420,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
<IAIButton
|
<IAIButton
|
||||||
isDisabled={
|
isDisabled={
|
||||||
!isGFPGANAvailable ||
|
!isGFPGANAvailable ||
|
||||||
!image ||
|
!imageDTO ||
|
||||||
!(isConnected && !isProcessing) ||
|
!(isConnected && !isProcessing) ||
|
||||||
!facetoolStrength
|
!facetoolStrength
|
||||||
}
|
}
|
||||||
@ -474,7 +451,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
<IAIButton
|
<IAIButton
|
||||||
isDisabled={
|
isDisabled={
|
||||||
!isESRGANAvailable ||
|
!isESRGANAvailable ||
|
||||||
!image ||
|
!imageDTO ||
|
||||||
!(isConnected && !isProcessing) ||
|
!(isConnected && !isProcessing) ||
|
||||||
!upscalingLevel
|
!upscalingLevel
|
||||||
}
|
}
|
||||||
|
@ -4,13 +4,14 @@ import { stateSelector } from 'app/store/store';
|
|||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu';
|
import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu';
|
||||||
import { memo, useMemo } from 'react';
|
import { MouseEvent, memo, useCallback, useMemo } from 'react';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
|
import { menuListMotionProps } from 'theme/components/menu';
|
||||||
import MultipleSelectionMenuItems from './MultipleSelectionMenuItems';
|
import MultipleSelectionMenuItems from './MultipleSelectionMenuItems';
|
||||||
import SingleSelectionMenuItems from './SingleSelectionMenuItems';
|
import SingleSelectionMenuItems from './SingleSelectionMenuItems';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
imageDTO: ImageDTO;
|
imageDTO: ImageDTO | undefined;
|
||||||
children: ContextMenuProps<HTMLDivElement>['children'];
|
children: ContextMenuProps<HTMLDivElement>['children'];
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -31,18 +32,32 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => {
|
|||||||
|
|
||||||
const { selectionCount } = useAppSelector(selector);
|
const { selectionCount } = useAppSelector(selector);
|
||||||
|
|
||||||
|
const handleContextMenu = useCallback((e: MouseEvent<HTMLDivElement>) => {
|
||||||
|
e.preventDefault();
|
||||||
|
}, []);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<ContextMenu<HTMLDivElement>
|
<ContextMenu<HTMLDivElement>
|
||||||
menuProps={{ size: 'sm', isLazy: true }}
|
menuProps={{ size: 'sm', isLazy: true }}
|
||||||
renderMenu={() => (
|
menuButtonProps={{
|
||||||
<MenuList sx={{ visibility: 'visible !important' }}>
|
bg: 'transparent',
|
||||||
{selectionCount === 1 ? (
|
_hover: { bg: 'transparent' },
|
||||||
<SingleSelectionMenuItems imageDTO={imageDTO} />
|
}}
|
||||||
) : (
|
renderMenu={() =>
|
||||||
<MultipleSelectionMenuItems />
|
imageDTO ? (
|
||||||
)}
|
<MenuList
|
||||||
</MenuList>
|
sx={{ visibility: 'visible !important' }}
|
||||||
)}
|
motionProps={menuListMotionProps}
|
||||||
|
onContextMenu={handleContextMenu}
|
||||||
|
>
|
||||||
|
{selectionCount === 1 ? (
|
||||||
|
<SingleSelectionMenuItems imageDTO={imageDTO} />
|
||||||
|
) : (
|
||||||
|
<MultipleSelectionMenuItems />
|
||||||
|
)}
|
||||||
|
</MenuList>
|
||||||
|
) : null
|
||||||
|
}
|
||||||
>
|
>
|
||||||
{children}
|
{children}
|
||||||
</ContextMenu>
|
</ContextMenu>
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import { ExternalLinkIcon } from '@chakra-ui/icons';
|
import { Link, MenuItem } from '@chakra-ui/react';
|
||||||
import { MenuItem } from '@chakra-ui/react';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppToaster } from 'app/components/Toaster';
|
import { useAppToaster } from 'app/components/Toaster';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
@ -14,11 +13,21 @@ import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletio
|
|||||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
|
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
|
||||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||||
import { memo, useCallback, useContext, useMemo } from 'react';
|
import { memo, useCallback, useContext, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { FaFolder, FaShare, FaTrash } from 'react-icons/fa';
|
import {
|
||||||
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
|
FaAsterisk,
|
||||||
|
FaCopy,
|
||||||
|
FaDownload,
|
||||||
|
FaExternalLinkAlt,
|
||||||
|
FaFolder,
|
||||||
|
FaQuoteRight,
|
||||||
|
FaSeedling,
|
||||||
|
FaShare,
|
||||||
|
FaTrash,
|
||||||
|
} from 'react-icons/fa';
|
||||||
import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages';
|
import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages';
|
||||||
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
|
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
@ -61,6 +70,9 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
|
|
||||||
const { currentData } = useGetImageMetadataQuery(imageDTO.image_name);
|
const { currentData } = useGetImageMetadataQuery(imageDTO.image_name);
|
||||||
|
|
||||||
|
const { isClipboardAPIAvailable, copyImageToClipboard } =
|
||||||
|
useCopyImageToClipboard();
|
||||||
|
|
||||||
const metadata = currentData?.metadata;
|
const metadata = currentData?.metadata;
|
||||||
|
|
||||||
const handleDelete = useCallback(() => {
|
const handleDelete = useCallback(() => {
|
||||||
@ -130,13 +142,27 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
dispatch(imagesAddedToBatch([imageDTO.image_name]));
|
dispatch(imagesAddedToBatch([imageDTO.image_name]));
|
||||||
}, [dispatch, imageDTO.image_name]);
|
}, [dispatch, imageDTO.image_name]);
|
||||||
|
|
||||||
|
const handleCopyImage = useCallback(() => {
|
||||||
|
copyImageToClipboard(imageDTO.image_url);
|
||||||
|
}, [copyImageToClipboard, imageDTO.image_url]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<MenuItem icon={<ExternalLinkIcon />} onClickCapture={handleOpenInNewTab}>
|
<Link href={imageDTO.image_url} target="_blank">
|
||||||
{t('common.openInNewTab')}
|
<MenuItem
|
||||||
</MenuItem>
|
icon={<FaExternalLinkAlt />}
|
||||||
|
onClickCapture={handleOpenInNewTab}
|
||||||
|
>
|
||||||
|
{t('common.openInNewTab')}
|
||||||
|
</MenuItem>
|
||||||
|
</Link>
|
||||||
|
{isClipboardAPIAvailable && (
|
||||||
|
<MenuItem icon={<FaCopy />} onClickCapture={handleCopyImage}>
|
||||||
|
{t('parameters.copyImage')}
|
||||||
|
</MenuItem>
|
||||||
|
)}
|
||||||
<MenuItem
|
<MenuItem
|
||||||
icon={<IoArrowUndoCircleOutline />}
|
icon={<FaQuoteRight />}
|
||||||
onClickCapture={handleRecallPrompt}
|
onClickCapture={handleRecallPrompt}
|
||||||
isDisabled={
|
isDisabled={
|
||||||
metadata?.positive_prompt === undefined &&
|
metadata?.positive_prompt === undefined &&
|
||||||
@ -147,14 +173,14 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
</MenuItem>
|
</MenuItem>
|
||||||
|
|
||||||
<MenuItem
|
<MenuItem
|
||||||
icon={<IoArrowUndoCircleOutline />}
|
icon={<FaSeedling />}
|
||||||
onClickCapture={handleRecallSeed}
|
onClickCapture={handleRecallSeed}
|
||||||
isDisabled={metadata?.seed === undefined}
|
isDisabled={metadata?.seed === undefined}
|
||||||
>
|
>
|
||||||
{t('parameters.useSeed')}
|
{t('parameters.useSeed')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
<MenuItem
|
<MenuItem
|
||||||
icon={<IoArrowUndoCircleOutline />}
|
icon={<FaAsterisk />}
|
||||||
onClickCapture={handleUseAllParameters}
|
onClickCapture={handleUseAllParameters}
|
||||||
isDisabled={!metadata}
|
isDisabled={!metadata}
|
||||||
>
|
>
|
||||||
@ -193,6 +219,11 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
Remove from Board
|
Remove from Board
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
)}
|
)}
|
||||||
|
<Link download={true} href={imageDTO.image_url} target="_blank">
|
||||||
|
<MenuItem icon={<FaDownload />} w="100%">
|
||||||
|
{t('parameters.downloadImage')}
|
||||||
|
</MenuItem>
|
||||||
|
</Link>
|
||||||
<MenuItem
|
<MenuItem
|
||||||
sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
|
sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
|
||||||
icon={<FaTrash />}
|
icon={<FaTrash />}
|
||||||
|
@ -16,14 +16,13 @@ import {
|
|||||||
ASSETS_CATEGORIES,
|
ASSETS_CATEGORIES,
|
||||||
IMAGE_CATEGORIES,
|
IMAGE_CATEGORIES,
|
||||||
IMAGE_LIMIT,
|
IMAGE_LIMIT,
|
||||||
selectImagesAll,
|
|
||||||
} from 'features/gallery//store/gallerySlice';
|
} from 'features/gallery//store/gallerySlice';
|
||||||
import { selectFilteredImages } from 'features/gallery/store/gallerySelectors';
|
import { selectFilteredImages } from 'features/gallery/store/gallerySelectors';
|
||||||
import { VirtuosoGrid } from 'react-virtuoso';
|
import { VirtuosoGrid } from 'react-virtuoso';
|
||||||
import { receivedPageOfImages } from 'services/api/thunks/image';
|
import { receivedPageOfImages } from 'services/api/thunks/image';
|
||||||
|
import { useListBoardImagesQuery } from '../../../../services/api/endpoints/boardImages';
|
||||||
import ImageGridItemContainer from './ImageGridItemContainer';
|
import ImageGridItemContainer from './ImageGridItemContainer';
|
||||||
import ImageGridListContainer from './ImageGridListContainer';
|
import ImageGridListContainer from './ImageGridListContainer';
|
||||||
import { useListBoardImagesQuery } from '../../../../services/api/endpoints/boardImages';
|
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector, selectFilteredImages],
|
[stateSelector, selectFilteredImages],
|
||||||
@ -180,7 +179,6 @@ const GalleryImageGrid = () => {
|
|||||||
</Box>
|
</Box>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
console.log({ selectedBoardId });
|
|
||||||
|
|
||||||
if (status !== 'rejected') {
|
if (status !== 'rejected') {
|
||||||
return (
|
return (
|
||||||
|
@ -110,8 +110,11 @@ const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
|
|||||||
return (
|
return (
|
||||||
<div ref={ref} {...others}>
|
<div ref={ref} {...others}>
|
||||||
<div>
|
<div>
|
||||||
<Text>{label}</Text>
|
<Text fontWeight={600}>{label}</Text>
|
||||||
<Text size="xs" color="base.600">
|
<Text
|
||||||
|
size="xs"
|
||||||
|
sx={{ color: 'base.600', _dark: { color: 'base.500' } }}
|
||||||
|
>
|
||||||
{description}
|
{description}
|
||||||
</Text>
|
</Text>
|
||||||
</div>
|
</div>
|
||||||
|
@ -20,8 +20,8 @@ const IAINodeHeader = (props: IAINodeHeaderProps) => {
|
|||||||
justifyContent: 'space-between',
|
justifyContent: 'space-between',
|
||||||
px: 2,
|
px: 2,
|
||||||
py: 1,
|
py: 1,
|
||||||
bg: 'base.300',
|
bg: 'base.100',
|
||||||
_dark: { bg: 'base.700' },
|
_dark: { bg: 'base.900' },
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Tooltip label={nodeId}>
|
<Tooltip label={nodeId}>
|
||||||
@ -30,7 +30,7 @@ const IAINodeHeader = (props: IAINodeHeaderProps) => {
|
|||||||
sx={{
|
sx={{
|
||||||
fontWeight: 600,
|
fontWeight: 600,
|
||||||
color: 'base.900',
|
color: 'base.900',
|
||||||
_dark: { color: 'base.100' },
|
_dark: { color: 'base.200' },
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{title}
|
{title}
|
||||||
|
@ -151,7 +151,6 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
|
|||||||
nodeId={nodeId}
|
nodeId={nodeId}
|
||||||
field={field}
|
field={field}
|
||||||
template={template}
|
template={template}
|
||||||
base_models={['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']}
|
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -59,7 +59,7 @@ export const InvocationComponent = memo((props: NodeProps<InvocationValue>) => {
|
|||||||
flexDirection: 'column',
|
flexDirection: 'column',
|
||||||
borderBottomRadius: 'md',
|
borderBottomRadius: 'md',
|
||||||
py: 2,
|
py: 2,
|
||||||
bg: 'base.200',
|
bg: 'base.150',
|
||||||
_dark: { bg: 'base.800' },
|
_dark: { bg: 'base.800' },
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
import 'reactflow/dist/style.css';
|
|
||||||
import { Box } from '@chakra-ui/react';
|
import { Box } from '@chakra-ui/react';
|
||||||
import { ReactFlowProvider } from 'reactflow';
|
import { ReactFlowProvider } from 'reactflow';
|
||||||
|
import 'reactflow/dist/style.css';
|
||||||
|
|
||||||
import { Flow } from './Flow';
|
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
|
import { Flow } from './Flow';
|
||||||
|
|
||||||
const NodeEditor = () => {
|
const NodeEditor = () => {
|
||||||
return (
|
return (
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
import { Box, useToken } from '@chakra-ui/react';
|
import { Box, useToken } from '@chakra-ui/react';
|
||||||
import { NODE_MIN_WIDTH } from 'app/constants';
|
import { NODE_MIN_WIDTH } from 'app/constants';
|
||||||
|
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { PropsWithChildren } from 'react';
|
import { PropsWithChildren } from 'react';
|
||||||
import { DRAG_HANDLE_CLASSNAME } from '../hooks/useBuildInvocation';
|
import { DRAG_HANDLE_CLASSNAME } from '../hooks/useBuildInvocation';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
|
|
||||||
type NodeWrapperProps = PropsWithChildren & {
|
type NodeWrapperProps = PropsWithChildren & {
|
||||||
selected: boolean;
|
selected: boolean;
|
||||||
|
@ -1,17 +1,36 @@
|
|||||||
import { ButtonGroup } from '@chakra-ui/react';
|
import { ButtonGroup, Tooltip } from '@chakra-ui/react';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { FaCode, FaExpand, FaMinus, FaPlus } from 'react-icons/fa';
|
import {
|
||||||
|
FaCode,
|
||||||
|
FaExpand,
|
||||||
|
FaMinus,
|
||||||
|
FaPlus,
|
||||||
|
FaInfo,
|
||||||
|
FaMapMarkerAlt,
|
||||||
|
} from 'react-icons/fa';
|
||||||
import { useReactFlow } from 'reactflow';
|
import { useReactFlow } from 'reactflow';
|
||||||
import { shouldShowGraphOverlayChanged } from '../store/nodesSlice';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import {
|
||||||
|
shouldShowGraphOverlayChanged,
|
||||||
|
shouldShowFieldTypeLegendChanged,
|
||||||
|
shouldShowMinimapPanelChanged,
|
||||||
|
} from '../store/nodesSlice';
|
||||||
|
|
||||||
const ViewportControls = () => {
|
const ViewportControls = () => {
|
||||||
|
const { t } = useTranslation();
|
||||||
const { zoomIn, zoomOut, fitView } = useReactFlow();
|
const { zoomIn, zoomOut, fitView } = useReactFlow();
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const shouldShowGraphOverlay = useAppSelector(
|
const shouldShowGraphOverlay = useAppSelector(
|
||||||
(state) => state.nodes.shouldShowGraphOverlay
|
(state) => state.nodes.shouldShowGraphOverlay
|
||||||
);
|
);
|
||||||
|
const shouldShowFieldTypeLegend = useAppSelector(
|
||||||
|
(state) => state.nodes.shouldShowFieldTypeLegend
|
||||||
|
);
|
||||||
|
const shouldShowMinimapPanel = useAppSelector(
|
||||||
|
(state) => state.nodes.shouldShowMinimapPanel
|
||||||
|
);
|
||||||
|
|
||||||
const handleClickedZoomIn = useCallback(() => {
|
const handleClickedZoomIn = useCallback(() => {
|
||||||
zoomIn();
|
zoomIn();
|
||||||
@ -29,29 +48,64 @@ const ViewportControls = () => {
|
|||||||
dispatch(shouldShowGraphOverlayChanged(!shouldShowGraphOverlay));
|
dispatch(shouldShowGraphOverlayChanged(!shouldShowGraphOverlay));
|
||||||
}, [shouldShowGraphOverlay, dispatch]);
|
}, [shouldShowGraphOverlay, dispatch]);
|
||||||
|
|
||||||
|
const handleClickedToggleFieldTypeLegend = useCallback(() => {
|
||||||
|
dispatch(shouldShowFieldTypeLegendChanged(!shouldShowFieldTypeLegend));
|
||||||
|
}, [shouldShowFieldTypeLegend, dispatch]);
|
||||||
|
|
||||||
|
const handleClickedToggleMiniMapPanel = useCallback(() => {
|
||||||
|
dispatch(shouldShowMinimapPanelChanged(!shouldShowMinimapPanel));
|
||||||
|
}, [shouldShowMinimapPanel, dispatch]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<ButtonGroup isAttached orientation="vertical">
|
<ButtonGroup isAttached orientation="vertical">
|
||||||
<IAIIconButton
|
<Tooltip label={t('nodes.zoomInNodes')}>
|
||||||
onClick={handleClickedZoomIn}
|
<IAIIconButton onClick={handleClickedZoomIn} icon={<FaPlus />} />
|
||||||
aria-label="Zoom In"
|
</Tooltip>
|
||||||
icon={<FaPlus />}
|
<Tooltip label={t('nodes.zoomOutNodes')}>
|
||||||
/>
|
<IAIIconButton onClick={handleClickedZoomOut} icon={<FaMinus />} />
|
||||||
<IAIIconButton
|
</Tooltip>
|
||||||
onClick={handleClickedZoomOut}
|
<Tooltip label={t('nodes.fitViewportNodes')}>
|
||||||
aria-label="Zoom Out"
|
<IAIIconButton onClick={handleClickedFitView} icon={<FaExpand />} />
|
||||||
icon={<FaMinus />}
|
</Tooltip>
|
||||||
/>
|
<Tooltip
|
||||||
<IAIIconButton
|
label={
|
||||||
onClick={handleClickedFitView}
|
shouldShowGraphOverlay
|
||||||
aria-label="Fit to Viewport"
|
? t('nodes.hideGraphNodes')
|
||||||
icon={<FaExpand />}
|
: t('nodes.showGraphNodes')
|
||||||
/>
|
}
|
||||||
<IAIIconButton
|
>
|
||||||
isChecked={shouldShowGraphOverlay}
|
<IAIIconButton
|
||||||
onClick={handleClickedToggleGraphOverlay}
|
isChecked={shouldShowGraphOverlay}
|
||||||
aria-label="Show/Hide Graph"
|
onClick={handleClickedToggleGraphOverlay}
|
||||||
icon={<FaCode />}
|
icon={<FaCode />}
|
||||||
/>
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
<Tooltip
|
||||||
|
label={
|
||||||
|
shouldShowFieldTypeLegend
|
||||||
|
? t('nodes.hideLegendNodes')
|
||||||
|
: t('nodes.showLegendNodes')
|
||||||
|
}
|
||||||
|
>
|
||||||
|
<IAIIconButton
|
||||||
|
isChecked={shouldShowFieldTypeLegend}
|
||||||
|
onClick={handleClickedToggleFieldTypeLegend}
|
||||||
|
icon={<FaInfo />}
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
<Tooltip
|
||||||
|
label={
|
||||||
|
shouldShowMinimapPanel
|
||||||
|
? t('nodes.hideMinimapnodes')
|
||||||
|
: t('nodes.showMinimapnodes')
|
||||||
|
}
|
||||||
|
>
|
||||||
|
<IAIIconButton
|
||||||
|
isChecked={shouldShowMinimapPanel}
|
||||||
|
onClick={handleClickedToggleMiniMapPanel}
|
||||||
|
icon={<FaMapMarkerAlt />}
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
</ButtonGroup>
|
</ButtonGroup>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { useColorModeValue } from '@chakra-ui/react';
|
import { useColorModeValue } from '@chakra-ui/react';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { MiniMap } from 'reactflow';
|
import { MiniMap } from 'reactflow';
|
||||||
@ -12,6 +14,10 @@ const MinimapPanel = () => {
|
|||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const shouldShowMinimapPanel = useAppSelector(
|
||||||
|
(state: RootState) => state.nodes.shouldShowMinimapPanel
|
||||||
|
);
|
||||||
|
|
||||||
const nodeColor = useColorModeValue(
|
const nodeColor = useColorModeValue(
|
||||||
'var(--invokeai-colors-accent-300)',
|
'var(--invokeai-colors-accent-300)',
|
||||||
'var(--invokeai-colors-accent-700)'
|
'var(--invokeai-colors-accent-700)'
|
||||||
@ -23,15 +29,19 @@ const MinimapPanel = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<MiniMap
|
<>
|
||||||
nodeStrokeWidth={3}
|
{shouldShowMinimapPanel && (
|
||||||
pannable
|
<MiniMap
|
||||||
zoomable
|
nodeStrokeWidth={3}
|
||||||
nodeBorderRadius={30}
|
pannable
|
||||||
style={miniMapStyle}
|
zoomable
|
||||||
nodeColor={nodeColor}
|
nodeBorderRadius={30}
|
||||||
maskColor={maskColor}
|
style={miniMapStyle}
|
||||||
/>
|
nodeColor={nodeColor}
|
||||||
|
maskColor={maskColor}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -9,10 +9,13 @@ const TopRightPanel = () => {
|
|||||||
const shouldShowGraphOverlay = useAppSelector(
|
const shouldShowGraphOverlay = useAppSelector(
|
||||||
(state: RootState) => state.nodes.shouldShowGraphOverlay
|
(state: RootState) => state.nodes.shouldShowGraphOverlay
|
||||||
);
|
);
|
||||||
|
const shouldShowFieldTypeLegend = useAppSelector(
|
||||||
|
(state: RootState) => state.nodes.shouldShowFieldTypeLegend
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Panel position="top-right">
|
<Panel position="top-right">
|
||||||
<FieldTypeLegend />
|
{shouldShowFieldTypeLegend && <FieldTypeLegend />}
|
||||||
{shouldShowGraphOverlay && <NodeGraphOverlay />}
|
{shouldShowGraphOverlay && <NodeGraphOverlay />}
|
||||||
</Panel>
|
</Panel>
|
||||||
);
|
);
|
||||||
|
@ -32,6 +32,8 @@ export type NodesState = {
|
|||||||
invocationTemplates: Record<string, InvocationTemplate>;
|
invocationTemplates: Record<string, InvocationTemplate>;
|
||||||
connectionStartParams: OnConnectStartParams | null;
|
connectionStartParams: OnConnectStartParams | null;
|
||||||
shouldShowGraphOverlay: boolean;
|
shouldShowGraphOverlay: boolean;
|
||||||
|
shouldShowFieldTypeLegend: boolean;
|
||||||
|
shouldShowMinimapPanel: boolean;
|
||||||
editorInstance: ReactFlowInstance | undefined;
|
editorInstance: ReactFlowInstance | undefined;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -42,6 +44,8 @@ export const initialNodesState: NodesState = {
|
|||||||
invocationTemplates: {},
|
invocationTemplates: {},
|
||||||
connectionStartParams: null,
|
connectionStartParams: null,
|
||||||
shouldShowGraphOverlay: false,
|
shouldShowGraphOverlay: false,
|
||||||
|
shouldShowFieldTypeLegend: false,
|
||||||
|
shouldShowMinimapPanel: true,
|
||||||
editorInstance: undefined,
|
editorInstance: undefined,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -125,6 +129,15 @@ const nodesSlice = createSlice({
|
|||||||
shouldShowGraphOverlayChanged: (state, action: PayloadAction<boolean>) => {
|
shouldShowGraphOverlayChanged: (state, action: PayloadAction<boolean>) => {
|
||||||
state.shouldShowGraphOverlay = action.payload;
|
state.shouldShowGraphOverlay = action.payload;
|
||||||
},
|
},
|
||||||
|
shouldShowFieldTypeLegendChanged: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<boolean>
|
||||||
|
) => {
|
||||||
|
state.shouldShowFieldTypeLegend = action.payload;
|
||||||
|
},
|
||||||
|
shouldShowMinimapPanelChanged: (state, action: PayloadAction<boolean>) => {
|
||||||
|
state.shouldShowMinimapPanel = action.payload;
|
||||||
|
},
|
||||||
nodeTemplatesBuilt: (
|
nodeTemplatesBuilt: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<Record<string, InvocationTemplate>>
|
action: PayloadAction<Record<string, InvocationTemplate>>
|
||||||
@ -161,6 +174,8 @@ export const {
|
|||||||
connectionStarted,
|
connectionStarted,
|
||||||
connectionEnded,
|
connectionEnded,
|
||||||
shouldShowGraphOverlayChanged,
|
shouldShowGraphOverlayChanged,
|
||||||
|
shouldShowFieldTypeLegendChanged,
|
||||||
|
shouldShowMinimapPanelChanged,
|
||||||
nodeTemplatesBuilt,
|
nodeTemplatesBuilt,
|
||||||
nodeEditorReset,
|
nodeEditorReset,
|
||||||
imageCollectionFieldValueChanged,
|
imageCollectionFieldValueChanged,
|
||||||
|
@ -36,7 +36,7 @@ const ParamMainModelSelect = () => {
|
|||||||
const data: SelectItem[] = [];
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
forEach(mainModels.entities, (model, id) => {
|
forEach(mainModels.entities, (model, id) => {
|
||||||
if (!model) {
|
if (!model || ['sdxl', 'sdxl-refiner'].includes(model.base_model)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ import {
|
|||||||
ModalOverlay,
|
ModalOverlay,
|
||||||
useDisclosure,
|
useDisclosure,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { cloneElement, ReactElement } from 'react';
|
import { ReactElement, cloneElement } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import HotkeysModalItem from './HotkeysModalItem';
|
import HotkeysModalItem from './HotkeysModalItem';
|
||||||
|
|
||||||
@ -65,11 +65,6 @@ export default function HotkeysModal({ children }: HotkeysModalProps) {
|
|||||||
desc: t('hotkeys.pinOptions.desc'),
|
desc: t('hotkeys.pinOptions.desc'),
|
||||||
hotkey: 'Shift+O',
|
hotkey: 'Shift+O',
|
||||||
},
|
},
|
||||||
{
|
|
||||||
title: t('hotkeys.toggleViewer.title'),
|
|
||||||
desc: t('hotkeys.toggleViewer.desc'),
|
|
||||||
hotkey: 'Z',
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
title: t('hotkeys.toggleGallery.title'),
|
title: t('hotkeys.toggleGallery.title'),
|
||||||
desc: t('hotkeys.toggleGallery.desc'),
|
desc: t('hotkeys.toggleGallery.desc'),
|
||||||
@ -85,12 +80,6 @@ export default function HotkeysModal({ children }: HotkeysModalProps) {
|
|||||||
desc: t('hotkeys.changeTabs.desc'),
|
desc: t('hotkeys.changeTabs.desc'),
|
||||||
hotkey: '1-5',
|
hotkey: '1-5',
|
||||||
},
|
},
|
||||||
|
|
||||||
{
|
|
||||||
title: t('hotkeys.consoleToggle.title'),
|
|
||||||
desc: t('hotkeys.consoleToggle.desc'),
|
|
||||||
hotkey: '`',
|
|
||||||
},
|
|
||||||
];
|
];
|
||||||
|
|
||||||
const generalHotkeys = [
|
const generalHotkeys = [
|
||||||
@ -109,11 +98,6 @@ export default function HotkeysModal({ children }: HotkeysModalProps) {
|
|||||||
desc: t('hotkeys.setParameters.desc'),
|
desc: t('hotkeys.setParameters.desc'),
|
||||||
hotkey: 'A',
|
hotkey: 'A',
|
||||||
},
|
},
|
||||||
{
|
|
||||||
title: t('hotkeys.restoreFaces.title'),
|
|
||||||
desc: t('hotkeys.restoreFaces.desc'),
|
|
||||||
hotkey: 'Shift+R',
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
title: t('hotkeys.upscale.title'),
|
title: t('hotkeys.upscale.title'),
|
||||||
desc: t('hotkeys.upscale.desc'),
|
desc: t('hotkeys.upscale.desc'),
|
||||||
|
@ -183,7 +183,7 @@ const SettingsModal = ({ children, config }: SettingsModalProps) => {
|
|||||||
>
|
>
|
||||||
<ModalOverlay />
|
<ModalOverlay />
|
||||||
<ModalContent>
|
<ModalContent>
|
||||||
<ModalHeader>{t('common.settingsLabel')}</ModalHeader>
|
<ModalHeader bg="none">{t('common.settingsLabel')}</ModalHeader>
|
||||||
<ModalCloseButton />
|
<ModalCloseButton />
|
||||||
<ModalBody>
|
<ModalBody>
|
||||||
<Flex sx={{ gap: 4, flexDirection: 'column' }}>
|
<Flex sx={{ gap: 4, flexDirection: 'column' }}>
|
||||||
@ -331,12 +331,15 @@ export default SettingsModal;
|
|||||||
const StyledFlex = (props: PropsWithChildren) => {
|
const StyledFlex = (props: PropsWithChildren) => {
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
layerStyle="second"
|
|
||||||
sx={{
|
sx={{
|
||||||
flexDirection: 'column',
|
flexDirection: 'column',
|
||||||
gap: 2,
|
gap: 2,
|
||||||
p: 4,
|
p: 4,
|
||||||
borderRadius: 'base',
|
borderRadius: 'base',
|
||||||
|
bg: 'base.100',
|
||||||
|
_dark: {
|
||||||
|
bg: 'base.900',
|
||||||
|
},
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{props.children}
|
{props.children}
|
||||||
|
@ -1,11 +1,10 @@
|
|||||||
import { UseToastOptions } from '@chakra-ui/react';
|
import { UseToastOptions } from '@chakra-ui/react';
|
||||||
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||||
import * as InvokeAI from 'app/types/invokeai';
|
|
||||||
|
|
||||||
import { InvokeLogLevel } from 'app/logging/useLogger';
|
import { InvokeLogLevel } from 'app/logging/useLogger';
|
||||||
import { userInvoked } from 'app/store/actions';
|
import { userInvoked } from 'app/store/actions';
|
||||||
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
|
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
|
||||||
import { TFuncKey, t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
import { LogLevelName } from 'roarr';
|
import { LogLevelName } from 'roarr';
|
||||||
import { imageUploaded } from 'services/api/thunks/image';
|
import { imageUploaded } from 'services/api/thunks/image';
|
||||||
import {
|
import {
|
||||||
@ -44,8 +43,6 @@ export interface SystemState {
|
|||||||
isCancelable: boolean;
|
isCancelable: boolean;
|
||||||
enableImageDebugging: boolean;
|
enableImageDebugging: boolean;
|
||||||
toastQueue: UseToastOptions[];
|
toastQueue: UseToastOptions[];
|
||||||
searchFolder: string | null;
|
|
||||||
foundModels: InvokeAI.FoundModel[] | null;
|
|
||||||
/**
|
/**
|
||||||
* The current progress image
|
* The current progress image
|
||||||
*/
|
*/
|
||||||
@ -79,7 +76,7 @@ export interface SystemState {
|
|||||||
*/
|
*/
|
||||||
consoleLogLevel: InvokeLogLevel;
|
consoleLogLevel: InvokeLogLevel;
|
||||||
shouldLogToConsole: boolean;
|
shouldLogToConsole: boolean;
|
||||||
statusTranslationKey: TFuncKey;
|
statusTranslationKey: any;
|
||||||
/**
|
/**
|
||||||
* When a session is canceled, its ID is stored here until a new session is created.
|
* When a session is canceled, its ID is stored here until a new session is created.
|
||||||
*/
|
*/
|
||||||
@ -106,8 +103,6 @@ export const initialSystemState: SystemState = {
|
|||||||
isCancelable: true,
|
isCancelable: true,
|
||||||
enableImageDebugging: false,
|
enableImageDebugging: false,
|
||||||
toastQueue: [],
|
toastQueue: [],
|
||||||
searchFolder: null,
|
|
||||||
foundModels: null,
|
|
||||||
progressImage: null,
|
progressImage: null,
|
||||||
shouldAntialiasProgressImage: false,
|
shouldAntialiasProgressImage: false,
|
||||||
sessionId: null,
|
sessionId: null,
|
||||||
@ -132,7 +127,7 @@ export const systemSlice = createSlice({
|
|||||||
setIsProcessing: (state, action: PayloadAction<boolean>) => {
|
setIsProcessing: (state, action: PayloadAction<boolean>) => {
|
||||||
state.isProcessing = action.payload;
|
state.isProcessing = action.payload;
|
||||||
},
|
},
|
||||||
setCurrentStatus: (state, action: PayloadAction<TFuncKey>) => {
|
setCurrentStatus: (state, action: any) => {
|
||||||
state.statusTranslationKey = action.payload;
|
state.statusTranslationKey = action.payload;
|
||||||
},
|
},
|
||||||
setShouldConfirmOnDelete: (state, action: PayloadAction<boolean>) => {
|
setShouldConfirmOnDelete: (state, action: PayloadAction<boolean>) => {
|
||||||
@ -153,15 +148,6 @@ export const systemSlice = createSlice({
|
|||||||
clearToastQueue: (state) => {
|
clearToastQueue: (state) => {
|
||||||
state.toastQueue = [];
|
state.toastQueue = [];
|
||||||
},
|
},
|
||||||
setSearchFolder: (state, action: PayloadAction<string | null>) => {
|
|
||||||
state.searchFolder = action.payload;
|
|
||||||
},
|
|
||||||
setFoundModels: (
|
|
||||||
state,
|
|
||||||
action: PayloadAction<InvokeAI.FoundModel[] | null>
|
|
||||||
) => {
|
|
||||||
state.foundModels = action.payload;
|
|
||||||
},
|
|
||||||
/**
|
/**
|
||||||
* A cancel was scheduled
|
* A cancel was scheduled
|
||||||
*/
|
*/
|
||||||
@ -426,8 +412,6 @@ export const {
|
|||||||
setEnableImageDebugging,
|
setEnableImageDebugging,
|
||||||
addToast,
|
addToast,
|
||||||
clearToastQueue,
|
clearToastQueue,
|
||||||
setSearchFolder,
|
|
||||||
setFoundModels,
|
|
||||||
cancelScheduled,
|
cancelScheduled,
|
||||||
scheduledCancelAborted,
|
scheduledCancelAborted,
|
||||||
cancelTypeChanged,
|
cancelTypeChanged,
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@chakra-ui/react';
|
import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@chakra-ui/react';
|
||||||
import i18n from 'i18n';
|
import i18n from 'i18n';
|
||||||
import { ReactNode, memo } from 'react';
|
import { ReactNode, memo } from 'react';
|
||||||
import AddModelsPanel from './subpanels/AddModelsPanel';
|
import ImportModelsPanel from './subpanels/ImportModelsPanel';
|
||||||
import MergeModelsPanel from './subpanels/MergeModelsPanel';
|
import MergeModelsPanel from './subpanels/MergeModelsPanel';
|
||||||
import ModelManagerPanel from './subpanels/ModelManagerPanel';
|
import ModelManagerPanel from './subpanels/ModelManagerPanel';
|
||||||
|
|
||||||
type ModelManagerTabName = 'modelManager' | 'addModels' | 'mergeModels';
|
type ModelManagerTabName = 'modelManager' | 'importModels' | 'mergeModels';
|
||||||
|
|
||||||
type ModelManagerTabInfo = {
|
type ModelManagerTabInfo = {
|
||||||
id: ModelManagerTabName;
|
id: ModelManagerTabName;
|
||||||
@ -20,9 +20,9 @@ const tabs: ModelManagerTabInfo[] = [
|
|||||||
content: <ModelManagerPanel />,
|
content: <ModelManagerPanel />,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'addModels',
|
id: 'importModels',
|
||||||
label: i18n.t('modelManager.addModel'),
|
label: i18n.t('modelManager.importModels'),
|
||||||
content: <AddModelsPanel />,
|
content: <ImportModelsPanel />,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'mergeModels',
|
id: 'mergeModels',
|
||||||
@ -46,7 +46,7 @@ const ModelManagerTab = () => {
|
|||||||
</Tab>
|
</Tab>
|
||||||
))}
|
))}
|
||||||
</TabList>
|
</TabList>
|
||||||
<TabPanels sx={{ w: 'full', h: 'full', p: 4 }}>
|
<TabPanels sx={{ w: 'full', h: 'full' }}>
|
||||||
{tabs.map((tab) => (
|
{tabs.map((tab) => (
|
||||||
<TabPanel sx={{ w: 'full', h: 'full' }} key={tab.id}>
|
<TabPanel sx={{ w: 'full', h: 'full' }} key={tab.id}>
|
||||||
{tab.content}
|
{tab.content}
|
||||||
|
@ -0,0 +1,29 @@
|
|||||||
|
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||||
|
|
||||||
|
type ModelManagerState = {
|
||||||
|
searchFolder: string | null;
|
||||||
|
advancedAddScanModel: string | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
const initialModelManagerState: ModelManagerState = {
|
||||||
|
searchFolder: null,
|
||||||
|
advancedAddScanModel: null,
|
||||||
|
};
|
||||||
|
|
||||||
|
export const modelManagerSlice = createSlice({
|
||||||
|
name: 'modelmanager',
|
||||||
|
initialState: initialModelManagerState,
|
||||||
|
reducers: {
|
||||||
|
setSearchFolder: (state, action: PayloadAction<string | null>) => {
|
||||||
|
state.searchFolder = action.payload;
|
||||||
|
},
|
||||||
|
setAdvancedAddScanModel: (state, action: PayloadAction<string | null>) => {
|
||||||
|
state.advancedAddScanModel = action.payload;
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
export const { setSearchFolder, setAdvancedAddScanModel } =
|
||||||
|
modelManagerSlice.actions;
|
||||||
|
|
||||||
|
export default modelManagerSlice.reducer;
|
@ -0,0 +1,3 @@
|
|||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
|
||||||
|
export const modelmanagerSelector = (state: RootState) => state.modelmanager;
|
@ -1,43 +0,0 @@
|
|||||||
import { Divider, Flex, useColorMode } from '@chakra-ui/react';
|
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import IAIButton from 'common/components/IAIButton';
|
|
||||||
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import AddCheckpointModel from './AddModelsPanel/AddCheckpointModel';
|
|
||||||
import AddDiffusersModel from './AddModelsPanel/AddDiffusersModel';
|
|
||||||
|
|
||||||
export default function AddModelsPanel() {
|
|
||||||
const addNewModelUIOption = useAppSelector(
|
|
||||||
(state: RootState) => state.ui.addNewModelUIOption
|
|
||||||
);
|
|
||||||
|
|
||||||
const { colorMode } = useColorMode();
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex flexDirection="column" gap={4}>
|
|
||||||
<Flex columnGap={4}>
|
|
||||||
<IAIButton
|
|
||||||
onClick={() => dispatch(setAddNewModelUIOption('ckpt'))}
|
|
||||||
isChecked={addNewModelUIOption == 'ckpt'}
|
|
||||||
>
|
|
||||||
{t('modelManager.addCheckpointModel')}
|
|
||||||
</IAIButton>
|
|
||||||
<IAIButton
|
|
||||||
onClick={() => dispatch(setAddNewModelUIOption('diffusers'))}
|
|
||||||
isChecked={addNewModelUIOption == 'diffusers'}
|
|
||||||
>
|
|
||||||
{t('modelManager.addDiffuserModel')}
|
|
||||||
</IAIButton>
|
|
||||||
</Flex>
|
|
||||||
|
|
||||||
<Divider />
|
|
||||||
|
|
||||||
{addNewModelUIOption == 'ckpt' && <AddCheckpointModel />}
|
|
||||||
{addNewModelUIOption == 'diffusers' && <AddDiffusersModel />}
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
}
|
|
@ -1,337 +0,0 @@
|
|||||||
import {
|
|
||||||
Flex,
|
|
||||||
FormControl,
|
|
||||||
FormErrorMessage,
|
|
||||||
FormHelperText,
|
|
||||||
FormLabel,
|
|
||||||
HStack,
|
|
||||||
Text,
|
|
||||||
VStack,
|
|
||||||
} from '@chakra-ui/react';
|
|
||||||
|
|
||||||
import IAIButton from 'common/components/IAIButton';
|
|
||||||
import IAIInput from 'common/components/IAIInput';
|
|
||||||
import IAINumberInput from 'common/components/IAINumberInput';
|
|
||||||
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
|
||||||
import React from 'react';
|
|
||||||
|
|
||||||
// import { addNewModel } from 'app/socketio/actions';
|
|
||||||
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
|
|
||||||
import { Field, Formik } from 'formik';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
|
|
||||||
import type { RootState } from 'app/store/store';
|
|
||||||
import type { InvokeModelConfigProps } from 'app/types/invokeai';
|
|
||||||
import IAIForm from 'common/components/IAIForm';
|
|
||||||
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
|
|
||||||
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
|
|
||||||
import type { FieldInputProps, FormikProps } from 'formik';
|
|
||||||
import SearchModels from './SearchModels';
|
|
||||||
|
|
||||||
const MIN_MODEL_SIZE = 64;
|
|
||||||
const MAX_MODEL_SIZE = 2048;
|
|
||||||
|
|
||||||
export default function AddCheckpointModel() {
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
const isProcessing = useAppSelector(
|
|
||||||
(state: RootState) => state.system.isProcessing
|
|
||||||
);
|
|
||||||
|
|
||||||
function hasWhiteSpace(s: string) {
|
|
||||||
return /\s/.test(s);
|
|
||||||
}
|
|
||||||
|
|
||||||
function baseValidation(value: string) {
|
|
||||||
let error;
|
|
||||||
if (hasWhiteSpace(value)) error = t('modelManager.cannotUseSpaces');
|
|
||||||
return error;
|
|
||||||
}
|
|
||||||
|
|
||||||
const addModelFormValues: InvokeModelConfigProps = {
|
|
||||||
name: '',
|
|
||||||
description: '',
|
|
||||||
config: 'configs/stable-diffusion/v1-inference.yaml',
|
|
||||||
weights: '',
|
|
||||||
vae: '',
|
|
||||||
width: 512,
|
|
||||||
height: 512,
|
|
||||||
format: 'ckpt',
|
|
||||||
default: false,
|
|
||||||
};
|
|
||||||
|
|
||||||
const addModelFormSubmitHandler = (values: InvokeModelConfigProps) => {
|
|
||||||
dispatch(addNewModel(values));
|
|
||||||
dispatch(setAddNewModelUIOption(null));
|
|
||||||
};
|
|
||||||
|
|
||||||
const [addManually, setAddmanually] = React.useState<boolean>(false);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<VStack gap={2} alignItems="flex-start">
|
|
||||||
<Flex columnGap={4}>
|
|
||||||
<IAISimpleCheckbox
|
|
||||||
isChecked={!addManually}
|
|
||||||
label={t('modelManager.scanForModels')}
|
|
||||||
onChange={() => setAddmanually(!addManually)}
|
|
||||||
/>
|
|
||||||
<IAISimpleCheckbox
|
|
||||||
label={t('modelManager.addManually')}
|
|
||||||
isChecked={addManually}
|
|
||||||
onChange={() => setAddmanually(!addManually)}
|
|
||||||
/>
|
|
||||||
</Flex>
|
|
||||||
|
|
||||||
{addManually ? (
|
|
||||||
<Formik
|
|
||||||
initialValues={addModelFormValues}
|
|
||||||
onSubmit={addModelFormSubmitHandler}
|
|
||||||
>
|
|
||||||
{({ handleSubmit, errors, touched }) => (
|
|
||||||
<IAIForm onSubmit={handleSubmit} sx={{ w: 'full' }}>
|
|
||||||
<VStack rowGap={2}>
|
|
||||||
<Text fontSize={20} fontWeight="bold" alignSelf="start">
|
|
||||||
{t('modelManager.manual')}
|
|
||||||
</Text>
|
|
||||||
{/* Name */}
|
|
||||||
<IAIFormItemWrapper>
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.name && touched.name}
|
|
||||||
isRequired
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="name" fontSize="sm">
|
|
||||||
{t('modelManager.name')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="name"
|
|
||||||
name="name"
|
|
||||||
type="text"
|
|
||||||
validate={baseValidation}
|
|
||||||
width="full"
|
|
||||||
/>
|
|
||||||
{!!errors.name && touched.name ? (
|
|
||||||
<FormErrorMessage>{errors.name}</FormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<FormHelperText margin={0}>
|
|
||||||
{t('modelManager.nameValidationMsg')}
|
|
||||||
</FormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
</IAIFormItemWrapper>
|
|
||||||
|
|
||||||
{/* Description */}
|
|
||||||
<IAIFormItemWrapper>
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.description && touched.description}
|
|
||||||
isRequired
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="description" fontSize="sm">
|
|
||||||
{t('modelManager.description')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="description"
|
|
||||||
name="description"
|
|
||||||
type="text"
|
|
||||||
width="full"
|
|
||||||
/>
|
|
||||||
{!!errors.description && touched.description ? (
|
|
||||||
<FormErrorMessage>
|
|
||||||
{errors.description}
|
|
||||||
</FormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<FormHelperText margin={0}>
|
|
||||||
{t('modelManager.descriptionValidationMsg')}
|
|
||||||
</FormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
</IAIFormItemWrapper>
|
|
||||||
|
|
||||||
{/* Config */}
|
|
||||||
<IAIFormItemWrapper>
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.config && touched.config}
|
|
||||||
isRequired
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="config" fontSize="sm">
|
|
||||||
{t('modelManager.config')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="config"
|
|
||||||
name="config"
|
|
||||||
type="text"
|
|
||||||
width="full"
|
|
||||||
/>
|
|
||||||
{!!errors.config && touched.config ? (
|
|
||||||
<FormErrorMessage>{errors.config}</FormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<FormHelperText margin={0}>
|
|
||||||
{t('modelManager.configValidationMsg')}
|
|
||||||
</FormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
</IAIFormItemWrapper>
|
|
||||||
|
|
||||||
{/* Weights */}
|
|
||||||
<IAIFormItemWrapper>
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.weights && touched.weights}
|
|
||||||
isRequired
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="config" fontSize="sm">
|
|
||||||
{t('modelManager.modelLocation')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="weights"
|
|
||||||
name="weights"
|
|
||||||
type="text"
|
|
||||||
width="full"
|
|
||||||
/>
|
|
||||||
{!!errors.weights && touched.weights ? (
|
|
||||||
<FormErrorMessage>{errors.weights}</FormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<FormHelperText margin={0}>
|
|
||||||
{t('modelManager.modelLocationValidationMsg')}
|
|
||||||
</FormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
</IAIFormItemWrapper>
|
|
||||||
|
|
||||||
{/* VAE */}
|
|
||||||
<IAIFormItemWrapper>
|
|
||||||
<FormControl isInvalid={!!errors.vae && touched.vae}>
|
|
||||||
<FormLabel htmlFor="vae" fontSize="sm">
|
|
||||||
{t('modelManager.vaeLocation')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="vae"
|
|
||||||
name="vae"
|
|
||||||
type="text"
|
|
||||||
width="full"
|
|
||||||
/>
|
|
||||||
{!!errors.vae && touched.vae ? (
|
|
||||||
<FormErrorMessage>{errors.vae}</FormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<FormHelperText margin={0}>
|
|
||||||
{t('modelManager.vaeLocationValidationMsg')}
|
|
||||||
</FormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
</IAIFormItemWrapper>
|
|
||||||
|
|
||||||
<HStack width="100%">
|
|
||||||
{/* Width */}
|
|
||||||
<IAIFormItemWrapper>
|
|
||||||
<FormControl isInvalid={!!errors.width && touched.width}>
|
|
||||||
<FormLabel htmlFor="width" fontSize="sm">
|
|
||||||
{t('modelManager.width')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field id="width" name="width">
|
|
||||||
{({
|
|
||||||
field,
|
|
||||||
form,
|
|
||||||
}: {
|
|
||||||
field: FieldInputProps<number>;
|
|
||||||
form: FormikProps<InvokeModelConfigProps>;
|
|
||||||
}) => (
|
|
||||||
<IAINumberInput
|
|
||||||
id="width"
|
|
||||||
name="width"
|
|
||||||
min={MIN_MODEL_SIZE}
|
|
||||||
max={MAX_MODEL_SIZE}
|
|
||||||
step={64}
|
|
||||||
value={form.values.width}
|
|
||||||
onChange={(value) =>
|
|
||||||
form.setFieldValue(field.name, Number(value))
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</Field>
|
|
||||||
|
|
||||||
{!!errors.width && touched.width ? (
|
|
||||||
<FormErrorMessage>{errors.width}</FormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<FormHelperText margin={0}>
|
|
||||||
{t('modelManager.widthValidationMsg')}
|
|
||||||
</FormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
</IAIFormItemWrapper>
|
|
||||||
|
|
||||||
{/* Height */}
|
|
||||||
<IAIFormItemWrapper>
|
|
||||||
<FormControl isInvalid={!!errors.height && touched.height}>
|
|
||||||
<FormLabel htmlFor="height" fontSize="sm">
|
|
||||||
{t('modelManager.height')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field id="height" name="height">
|
|
||||||
{({
|
|
||||||
field,
|
|
||||||
form,
|
|
||||||
}: {
|
|
||||||
field: FieldInputProps<number>;
|
|
||||||
form: FormikProps<InvokeModelConfigProps>;
|
|
||||||
}) => (
|
|
||||||
<IAINumberInput
|
|
||||||
id="height"
|
|
||||||
name="height"
|
|
||||||
min={MIN_MODEL_SIZE}
|
|
||||||
max={MAX_MODEL_SIZE}
|
|
||||||
step={64}
|
|
||||||
value={form.values.height}
|
|
||||||
onChange={(value) =>
|
|
||||||
form.setFieldValue(field.name, Number(value))
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</Field>
|
|
||||||
|
|
||||||
{!!errors.height && touched.height ? (
|
|
||||||
<FormErrorMessage>{errors.height}</FormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<FormHelperText margin={0}>
|
|
||||||
{t('modelManager.heightValidationMsg')}
|
|
||||||
</FormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
</IAIFormItemWrapper>
|
|
||||||
</HStack>
|
|
||||||
|
|
||||||
<IAIButton
|
|
||||||
type="submit"
|
|
||||||
className="modal-close-btn"
|
|
||||||
isLoading={isProcessing}
|
|
||||||
>
|
|
||||||
{t('modelManager.addModel')}
|
|
||||||
</IAIButton>
|
|
||||||
</VStack>
|
|
||||||
</IAIForm>
|
|
||||||
)}
|
|
||||||
</Formik>
|
|
||||||
) : (
|
|
||||||
<SearchModels />
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
);
|
|
||||||
}
|
|
@ -1,259 +0,0 @@
|
|||||||
import {
|
|
||||||
Flex,
|
|
||||||
FormControl,
|
|
||||||
FormErrorMessage,
|
|
||||||
FormHelperText,
|
|
||||||
FormLabel,
|
|
||||||
Text,
|
|
||||||
VStack,
|
|
||||||
} from '@chakra-ui/react';
|
|
||||||
import { InvokeDiffusersModelConfigProps } from 'app/types/invokeai';
|
|
||||||
// import { addNewModel } from 'app/socketio/actions';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import IAIButton from 'common/components/IAIButton';
|
|
||||||
import IAIInput from 'common/components/IAIInput';
|
|
||||||
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
|
|
||||||
import { Field, Formik } from 'formik';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
|
|
||||||
import type { RootState } from 'app/store/store';
|
|
||||||
import IAIForm from 'common/components/IAIForm';
|
|
||||||
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
|
|
||||||
|
|
||||||
export default function AddDiffusersModel() {
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
const isProcessing = useAppSelector(
|
|
||||||
(state: RootState) => state.system.isProcessing
|
|
||||||
);
|
|
||||||
|
|
||||||
function hasWhiteSpace(s: string) {
|
|
||||||
return /\s/.test(s);
|
|
||||||
}
|
|
||||||
|
|
||||||
function baseValidation(value: string) {
|
|
||||||
let error;
|
|
||||||
if (hasWhiteSpace(value)) error = t('modelManager.cannotUseSpaces');
|
|
||||||
return error;
|
|
||||||
}
|
|
||||||
|
|
||||||
const addModelFormValues: InvokeDiffusersModelConfigProps = {
|
|
||||||
name: '',
|
|
||||||
description: '',
|
|
||||||
repo_id: '',
|
|
||||||
path: '',
|
|
||||||
format: 'diffusers',
|
|
||||||
default: false,
|
|
||||||
vae: {
|
|
||||||
repo_id: '',
|
|
||||||
path: '',
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
const addModelFormSubmitHandler = (
|
|
||||||
values: InvokeDiffusersModelConfigProps
|
|
||||||
) => {
|
|
||||||
const diffusersModelToAdd = values;
|
|
||||||
|
|
||||||
if (values.path === '') delete diffusersModelToAdd.path;
|
|
||||||
if (values.repo_id === '') delete diffusersModelToAdd.repo_id;
|
|
||||||
if (values.vae.path === '') delete diffusersModelToAdd.vae.path;
|
|
||||||
if (values.vae.repo_id === '') delete diffusersModelToAdd.vae.repo_id;
|
|
||||||
|
|
||||||
dispatch(addNewModel(diffusersModelToAdd));
|
|
||||||
dispatch(setAddNewModelUIOption(null));
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex overflow="scroll" maxHeight={window.innerHeight - 270} width="100%">
|
|
||||||
<Formik
|
|
||||||
initialValues={addModelFormValues}
|
|
||||||
onSubmit={addModelFormSubmitHandler}
|
|
||||||
>
|
|
||||||
{({ handleSubmit, errors, touched }) => (
|
|
||||||
<IAIForm onSubmit={handleSubmit} w="full">
|
|
||||||
<VStack rowGap={2}>
|
|
||||||
<IAIFormItemWrapper>
|
|
||||||
{/* Name */}
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.name && touched.name}
|
|
||||||
isRequired
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="name" fontSize="sm">
|
|
||||||
{t('modelManager.name')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="name"
|
|
||||||
name="name"
|
|
||||||
type="text"
|
|
||||||
validate={baseValidation}
|
|
||||||
isRequired
|
|
||||||
/>
|
|
||||||
{!!errors.name && touched.name ? (
|
|
||||||
<FormErrorMessage>{errors.name}</FormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<FormHelperText margin={0}>
|
|
||||||
{t('modelManager.nameValidationMsg')}
|
|
||||||
</FormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
</IAIFormItemWrapper>
|
|
||||||
|
|
||||||
<IAIFormItemWrapper>
|
|
||||||
{/* Description */}
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.description && touched.description}
|
|
||||||
isRequired
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="description" fontSize="sm">
|
|
||||||
{t('modelManager.description')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="description"
|
|
||||||
name="description"
|
|
||||||
type="text"
|
|
||||||
isRequired
|
|
||||||
/>
|
|
||||||
{!!errors.description && touched.description ? (
|
|
||||||
<FormErrorMessage>{errors.description}</FormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<FormHelperText margin={0}>
|
|
||||||
{t('modelManager.descriptionValidationMsg')}
|
|
||||||
</FormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
</IAIFormItemWrapper>
|
|
||||||
|
|
||||||
<IAIFormItemWrapper>
|
|
||||||
<Text fontWeight="bold" fontSize="sm">
|
|
||||||
{t('modelManager.formMessageDiffusersModelLocation')}
|
|
||||||
</Text>
|
|
||||||
<Text
|
|
||||||
sx={{
|
|
||||||
fontSize: 'sm',
|
|
||||||
fontStyle: 'italic',
|
|
||||||
}}
|
|
||||||
variant="subtext"
|
|
||||||
>
|
|
||||||
{t('modelManager.formMessageDiffusersModelLocationDesc')}
|
|
||||||
</Text>
|
|
||||||
|
|
||||||
{/* Path */}
|
|
||||||
<FormControl isInvalid={!!errors.path && touched.path}>
|
|
||||||
<FormLabel htmlFor="path" fontSize="sm">
|
|
||||||
{t('modelManager.modelLocation')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field as={IAIInput} id="path" name="path" type="text" />
|
|
||||||
{!!errors.path && touched.path ? (
|
|
||||||
<FormErrorMessage>{errors.path}</FormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<FormHelperText margin={0}>
|
|
||||||
{t('modelManager.modelLocationValidationMsg')}
|
|
||||||
</FormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
{/* Repo ID */}
|
|
||||||
<FormControl isInvalid={!!errors.repo_id && touched.repo_id}>
|
|
||||||
<FormLabel htmlFor="repo_id" fontSize="sm">
|
|
||||||
{t('modelManager.repo_id')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="repo_id"
|
|
||||||
name="repo_id"
|
|
||||||
type="text"
|
|
||||||
/>
|
|
||||||
{!!errors.repo_id && touched.repo_id ? (
|
|
||||||
<FormErrorMessage>{errors.repo_id}</FormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<FormHelperText margin={0}>
|
|
||||||
{t('modelManager.repoIDValidationMsg')}
|
|
||||||
</FormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
</IAIFormItemWrapper>
|
|
||||||
|
|
||||||
<IAIFormItemWrapper>
|
|
||||||
{/* VAE Path */}
|
|
||||||
<Text fontWeight="bold">
|
|
||||||
{t('modelManager.formMessageDiffusersVAELocation')}
|
|
||||||
</Text>
|
|
||||||
<Text
|
|
||||||
sx={{
|
|
||||||
fontSize: 'sm',
|
|
||||||
fontStyle: 'italic',
|
|
||||||
}}
|
|
||||||
variant="subtext"
|
|
||||||
>
|
|
||||||
{t('modelManager.formMessageDiffusersVAELocationDesc')}
|
|
||||||
</Text>
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.vae?.path && touched.vae?.path}
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="vae.path" fontSize="sm">
|
|
||||||
{t('modelManager.vaeLocation')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="vae.path"
|
|
||||||
name="vae.path"
|
|
||||||
type="text"
|
|
||||||
/>
|
|
||||||
{!!errors.vae?.path && touched.vae?.path ? (
|
|
||||||
<FormErrorMessage>{errors.vae?.path}</FormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<FormHelperText margin={0}>
|
|
||||||
{t('modelManager.vaeLocationValidationMsg')}
|
|
||||||
</FormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
{/* VAE Repo ID */}
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.vae?.repo_id && touched.vae?.repo_id}
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="vae.repo_id" fontSize="sm">
|
|
||||||
{t('modelManager.vaeRepoID')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="vae.repo_id"
|
|
||||||
name="vae.repo_id"
|
|
||||||
type="text"
|
|
||||||
/>
|
|
||||||
{!!errors.vae?.repo_id && touched.vae?.repo_id ? (
|
|
||||||
<FormErrorMessage>{errors.vae?.repo_id}</FormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<FormHelperText margin={0}>
|
|
||||||
{t('modelManager.vaeRepoIDValidationMsg')}
|
|
||||||
</FormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
</IAIFormItemWrapper>
|
|
||||||
|
|
||||||
<IAIButton type="submit" isLoading={isProcessing}>
|
|
||||||
{t('modelManager.addModel')}
|
|
||||||
</IAIButton>
|
|
||||||
</VStack>
|
|
||||||
</IAIForm>
|
|
||||||
)}
|
|
||||||
</Formik>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
}
|
|
@ -0,0 +1,48 @@
|
|||||||
|
import { ButtonGroup, Flex } from '@chakra-ui/react';
|
||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
import { useState } from 'react';
|
||||||
|
import AdvancedAddModels from './AdvancedAddModels';
|
||||||
|
import SimpleAddModels from './SimpleAddModels';
|
||||||
|
|
||||||
|
export default function AddModels() {
|
||||||
|
const [addModelMode, setAddModelMode] = useState<'simple' | 'advanced'>(
|
||||||
|
'simple'
|
||||||
|
);
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
flexDirection="column"
|
||||||
|
width="100%"
|
||||||
|
overflow="scroll"
|
||||||
|
maxHeight={window.innerHeight - 250}
|
||||||
|
gap={4}
|
||||||
|
>
|
||||||
|
<ButtonGroup isAttached>
|
||||||
|
<IAIButton
|
||||||
|
size="sm"
|
||||||
|
isChecked={addModelMode == 'simple'}
|
||||||
|
onClick={() => setAddModelMode('simple')}
|
||||||
|
>
|
||||||
|
Simple
|
||||||
|
</IAIButton>
|
||||||
|
<IAIButton
|
||||||
|
size="sm"
|
||||||
|
isChecked={addModelMode == 'advanced'}
|
||||||
|
onClick={() => setAddModelMode('advanced')}
|
||||||
|
>
|
||||||
|
Advanced
|
||||||
|
</IAIButton>
|
||||||
|
</ButtonGroup>
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
p: 4,
|
||||||
|
borderRadius: 4,
|
||||||
|
background: 'base.200',
|
||||||
|
_dark: { background: 'base.800' },
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{addModelMode === 'simple' && <SimpleAddModels />}
|
||||||
|
{addModelMode === 'advanced' && <AdvancedAddModels />}
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}
|
@ -0,0 +1,143 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
import { useForm } from '@mantine/form';
|
||||||
|
import { makeToast } from 'app/components/Toaster';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
import IAIMantineTextInput from 'common/components/IAIMantineInput';
|
||||||
|
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { useState } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
|
||||||
|
import { CheckpointModelConfig } from 'services/api/types';
|
||||||
|
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
|
||||||
|
import BaseModelSelect from '../shared/BaseModelSelect';
|
||||||
|
import CheckpointConfigsSelect from '../shared/CheckpointConfigsSelect';
|
||||||
|
import ModelVariantSelect from '../shared/ModelVariantSelect';
|
||||||
|
|
||||||
|
type AdvancedAddCheckpointProps = {
|
||||||
|
model_path?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default function AdvancedAddCheckpoint(
|
||||||
|
props: AdvancedAddCheckpointProps
|
||||||
|
) {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { model_path } = props;
|
||||||
|
|
||||||
|
const advancedAddCheckpointForm = useForm<CheckpointModelConfig>({
|
||||||
|
initialValues: {
|
||||||
|
model_name: model_path
|
||||||
|
? model_path.split('\\').splice(-1)[0].split('.')[0]
|
||||||
|
: '',
|
||||||
|
base_model: 'sd-1',
|
||||||
|
model_type: 'main',
|
||||||
|
path: model_path ? model_path : '',
|
||||||
|
description: '',
|
||||||
|
model_format: 'checkpoint',
|
||||||
|
error: undefined,
|
||||||
|
vae: '',
|
||||||
|
variant: 'normal',
|
||||||
|
config: 'configs\\stable-diffusion\\v1-inference.yaml',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const [addMainModel] = useAddMainModelsMutation();
|
||||||
|
|
||||||
|
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
|
||||||
|
|
||||||
|
const advancedAddCheckpointFormHandler = (values: CheckpointModelConfig) => {
|
||||||
|
addMainModel({
|
||||||
|
body: values,
|
||||||
|
})
|
||||||
|
.unwrap()
|
||||||
|
.then((_) => {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: `Model Added: ${values.model_name}`,
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
advancedAddCheckpointForm.reset();
|
||||||
|
|
||||||
|
// Close Advanced Panel in Scan Models tab
|
||||||
|
if (model_path) {
|
||||||
|
dispatch(setAdvancedAddScanModel(null));
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
if (error) {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: 'Model Add Failed',
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<form
|
||||||
|
onSubmit={advancedAddCheckpointForm.onSubmit((v) =>
|
||||||
|
advancedAddCheckpointFormHandler(v)
|
||||||
|
)}
|
||||||
|
style={{ width: '100%' }}
|
||||||
|
>
|
||||||
|
<Flex flexDirection="column" gap={2}>
|
||||||
|
<IAIMantineTextInput
|
||||||
|
label="Model Name"
|
||||||
|
required
|
||||||
|
{...advancedAddCheckpointForm.getInputProps('model_name')}
|
||||||
|
/>
|
||||||
|
<BaseModelSelect
|
||||||
|
{...advancedAddCheckpointForm.getInputProps('base_model')}
|
||||||
|
/>
|
||||||
|
<IAIMantineTextInput
|
||||||
|
label="Model Location"
|
||||||
|
required
|
||||||
|
{...advancedAddCheckpointForm.getInputProps('path')}
|
||||||
|
/>
|
||||||
|
<IAIMantineTextInput
|
||||||
|
label="Description"
|
||||||
|
{...advancedAddCheckpointForm.getInputProps('description')}
|
||||||
|
/>
|
||||||
|
<IAIMantineTextInput
|
||||||
|
label="VAE Location"
|
||||||
|
{...advancedAddCheckpointForm.getInputProps('vae')}
|
||||||
|
/>
|
||||||
|
<ModelVariantSelect
|
||||||
|
{...advancedAddCheckpointForm.getInputProps('variant')}
|
||||||
|
/>
|
||||||
|
<Flex flexDirection="column" width="100%" gap={2}>
|
||||||
|
{!useCustomConfig ? (
|
||||||
|
<CheckpointConfigsSelect
|
||||||
|
required
|
||||||
|
width="100%"
|
||||||
|
{...advancedAddCheckpointForm.getInputProps('config')}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<IAIMantineTextInput
|
||||||
|
required
|
||||||
|
label="Custom Config File Location"
|
||||||
|
{...advancedAddCheckpointForm.getInputProps('config')}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
<IAISimpleCheckbox
|
||||||
|
isChecked={useCustomConfig}
|
||||||
|
onChange={() => setUseCustomConfig(!useCustomConfig)}
|
||||||
|
label="Use Custom Config"
|
||||||
|
/>
|
||||||
|
<IAIButton mt={2} type="submit">
|
||||||
|
{t('modelManager.addModel')}
|
||||||
|
</IAIButton>
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
</form>
|
||||||
|
);
|
||||||
|
}
|
@ -0,0 +1,113 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
import { useForm } from '@mantine/form';
|
||||||
|
import { makeToast } from 'app/components/Toaster';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
import IAIMantineTextInput from 'common/components/IAIMantineInput';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
|
||||||
|
import { DiffusersModelConfig } from 'services/api/types';
|
||||||
|
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
|
||||||
|
import BaseModelSelect from '../shared/BaseModelSelect';
|
||||||
|
import ModelVariantSelect from '../shared/ModelVariantSelect';
|
||||||
|
|
||||||
|
type AdvancedAddDiffusersProps = {
|
||||||
|
model_path?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default function AdvancedAddDiffusers(props: AdvancedAddDiffusersProps) {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { model_path } = props;
|
||||||
|
|
||||||
|
const [addMainModel] = useAddMainModelsMutation();
|
||||||
|
|
||||||
|
const advancedAddDiffusersForm = useForm<DiffusersModelConfig>({
|
||||||
|
initialValues: {
|
||||||
|
model_name: model_path ? model_path.split('\\').splice(-1)[0] : '',
|
||||||
|
base_model: 'sd-1',
|
||||||
|
model_type: 'main',
|
||||||
|
path: model_path ? model_path : '',
|
||||||
|
description: '',
|
||||||
|
model_format: 'diffusers',
|
||||||
|
error: undefined,
|
||||||
|
vae: '',
|
||||||
|
variant: 'normal',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const advancedAddDiffusersFormHandler = (values: DiffusersModelConfig) => {
|
||||||
|
addMainModel({
|
||||||
|
body: values,
|
||||||
|
})
|
||||||
|
.unwrap()
|
||||||
|
.then((_) => {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: `Model Added: ${values.model_name}`,
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
advancedAddDiffusersForm.reset();
|
||||||
|
// Close Advanced Panel in Scan Models tab
|
||||||
|
if (model_path) {
|
||||||
|
dispatch(setAdvancedAddScanModel(null));
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
if (error) {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: 'Model Add Failed',
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<form
|
||||||
|
onSubmit={advancedAddDiffusersForm.onSubmit((v) =>
|
||||||
|
advancedAddDiffusersFormHandler(v)
|
||||||
|
)}
|
||||||
|
style={{ width: '100%' }}
|
||||||
|
>
|
||||||
|
<Flex flexDirection="column" gap={2}>
|
||||||
|
<IAIMantineTextInput
|
||||||
|
required
|
||||||
|
label="Model Name"
|
||||||
|
{...advancedAddDiffusersForm.getInputProps('model_name')}
|
||||||
|
/>
|
||||||
|
<BaseModelSelect
|
||||||
|
{...advancedAddDiffusersForm.getInputProps('base_model')}
|
||||||
|
/>
|
||||||
|
<IAIMantineTextInput
|
||||||
|
required
|
||||||
|
label="Model Location"
|
||||||
|
placeholder="Provide the path to a local folder where your Diffusers Model is stored"
|
||||||
|
{...advancedAddDiffusersForm.getInputProps('path')}
|
||||||
|
/>
|
||||||
|
<IAIMantineTextInput
|
||||||
|
label="Description"
|
||||||
|
{...advancedAddDiffusersForm.getInputProps('description')}
|
||||||
|
/>
|
||||||
|
<IAIMantineTextInput
|
||||||
|
label="VAE Location"
|
||||||
|
{...advancedAddDiffusersForm.getInputProps('vae')}
|
||||||
|
/>
|
||||||
|
<ModelVariantSelect
|
||||||
|
{...advancedAddDiffusersForm.getInputProps('variant')}
|
||||||
|
/>
|
||||||
|
<IAIButton mt={2} type="submit">
|
||||||
|
{t('modelManager.addModel')}
|
||||||
|
</IAIButton>
|
||||||
|
</Flex>
|
||||||
|
</form>
|
||||||
|
);
|
||||||
|
}
|
@ -0,0 +1,46 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
import { SelectItem } from '@mantine/core';
|
||||||
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
|
import { useState } from 'react';
|
||||||
|
import AdvancedAddCheckpoint from './AdvancedAddCheckpoint';
|
||||||
|
import AdvancedAddDiffusers from './AdvancedAddDiffusers';
|
||||||
|
|
||||||
|
export const advancedAddModeData: SelectItem[] = [
|
||||||
|
{ label: 'Diffusers', value: 'diffusers' },
|
||||||
|
{ label: 'Checkpoint / Safetensors', value: 'checkpoint' },
|
||||||
|
];
|
||||||
|
|
||||||
|
export type ManualAddMode = 'diffusers' | 'checkpoint';
|
||||||
|
|
||||||
|
export default function AdvancedAddModels() {
|
||||||
|
const [advancedAddMode, setAdvancedAddMode] =
|
||||||
|
useState<ManualAddMode>('diffusers');
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex flexDirection="column" gap={4} width="100%">
|
||||||
|
<IAIMantineSelect
|
||||||
|
label="Model Type"
|
||||||
|
value={advancedAddMode}
|
||||||
|
data={advancedAddModeData}
|
||||||
|
onChange={(v) => {
|
||||||
|
if (!v) return;
|
||||||
|
setAdvancedAddMode(v as ManualAddMode);
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
p: 4,
|
||||||
|
borderRadius: 4,
|
||||||
|
bg: 'base.300',
|
||||||
|
_dark: {
|
||||||
|
bg: 'base.850',
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{advancedAddMode === 'diffusers' && <AdvancedAddDiffusers />}
|
||||||
|
{advancedAddMode === 'checkpoint' && <AdvancedAddCheckpoint />}
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}
|
@ -0,0 +1,253 @@
|
|||||||
|
import { Flex, Text } from '@chakra-ui/react';
|
||||||
|
import { makeToast } from 'app/components/Toaster';
|
||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
import IAIInput from 'common/components/IAIInput';
|
||||||
|
import IAIScrollArea from 'common/components/IAIScrollArea';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { difference, forEach, intersection, map, values } from 'lodash-es';
|
||||||
|
import { ChangeEvent, MouseEvent, useCallback, useState } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import {
|
||||||
|
SearchFolderResponse,
|
||||||
|
useGetMainModelsQuery,
|
||||||
|
useGetModelsInFolderQuery,
|
||||||
|
useImportMainModelsMutation,
|
||||||
|
} from 'services/api/endpoints/models';
|
||||||
|
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
|
||||||
|
|
||||||
|
export default function FoundModelsList() {
|
||||||
|
const searchFolder = useAppSelector(
|
||||||
|
(state: RootState) => state.modelmanager.searchFolder
|
||||||
|
);
|
||||||
|
const [nameFilter, setNameFilter] = useState<string>('');
|
||||||
|
|
||||||
|
// Get paths of models that are already installed
|
||||||
|
const { data: installedModels } = useGetMainModelsQuery();
|
||||||
|
|
||||||
|
// Get all model paths from a given directory
|
||||||
|
const { foundModels, alreadyInstalled, filteredModels } =
|
||||||
|
useGetModelsInFolderQuery(
|
||||||
|
{
|
||||||
|
search_path: searchFolder ? searchFolder : '',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
selectFromResult: ({ data }) => {
|
||||||
|
const installedModelValues = values(installedModels?.entities);
|
||||||
|
const installedModelPaths = map(installedModelValues, 'path');
|
||||||
|
// Only select models those that aren't already installed to Invoke
|
||||||
|
const notInstalledModels = difference(data, installedModelPaths);
|
||||||
|
const alreadyInstalled = intersection(data, installedModelPaths);
|
||||||
|
return {
|
||||||
|
foundModels: data,
|
||||||
|
alreadyInstalled: foundModelsFilter(alreadyInstalled, nameFilter),
|
||||||
|
filteredModels: foundModelsFilter(notInstalledModels, nameFilter),
|
||||||
|
};
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const quickAddHandler = useCallback(
|
||||||
|
(e: MouseEvent<HTMLButtonElement>) => {
|
||||||
|
const model_name = e.currentTarget.id.split('\\').splice(-1)[0];
|
||||||
|
importMainModel({
|
||||||
|
body: {
|
||||||
|
location: e.currentTarget.id,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.unwrap()
|
||||||
|
.then((_) => {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: `Added Model: ${model_name}`,
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
if (error) {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: 'Faile To Add Model',
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[dispatch, importMainModel]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||||
|
setNameFilter(e.target.value);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const renderModels = ({
|
||||||
|
models,
|
||||||
|
showActions = true,
|
||||||
|
}: {
|
||||||
|
models: string[];
|
||||||
|
showActions?: boolean;
|
||||||
|
}) => {
|
||||||
|
return models.map((model) => {
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
p: 4,
|
||||||
|
gap: 4,
|
||||||
|
alignItems: 'center',
|
||||||
|
borderRadius: 4,
|
||||||
|
bg: 'base.200',
|
||||||
|
_dark: {
|
||||||
|
bg: 'base.800',
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
key={model}
|
||||||
|
>
|
||||||
|
<Flex w="100%" sx={{ flexDirection: 'column', minW: '25%' }}>
|
||||||
|
<Text sx={{ fontWeight: 600 }}>
|
||||||
|
{model.split('\\').slice(-1)[0]}
|
||||||
|
</Text>
|
||||||
|
<Text
|
||||||
|
sx={{
|
||||||
|
fontSize: 'sm',
|
||||||
|
color: 'base.600',
|
||||||
|
_dark: {
|
||||||
|
color: 'base.400',
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{model}
|
||||||
|
</Text>
|
||||||
|
</Flex>
|
||||||
|
{showActions ? (
|
||||||
|
<Flex gap={2}>
|
||||||
|
<IAIButton
|
||||||
|
id={model}
|
||||||
|
onClick={quickAddHandler}
|
||||||
|
isLoading={isLoading}
|
||||||
|
>
|
||||||
|
Quick Add
|
||||||
|
</IAIButton>
|
||||||
|
<IAIButton
|
||||||
|
onClick={() => dispatch(setAdvancedAddScanModel(model))}
|
||||||
|
isLoading={isLoading}
|
||||||
|
>
|
||||||
|
Advanced
|
||||||
|
</IAIButton>
|
||||||
|
</Flex>
|
||||||
|
) : (
|
||||||
|
<Text
|
||||||
|
sx={{
|
||||||
|
fontWeight: 600,
|
||||||
|
p: 2,
|
||||||
|
borderRadius: 4,
|
||||||
|
color: 'accent.50',
|
||||||
|
bg: 'accent.400',
|
||||||
|
_dark: {
|
||||||
|
color: 'accent.100',
|
||||||
|
bg: 'accent.600',
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
Installed
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
const renderFoundModels = () => {
|
||||||
|
if (!searchFolder) return;
|
||||||
|
|
||||||
|
if (!foundModels || foundModels.length === 0) {
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
justifyContent: 'center',
|
||||||
|
alignItems: 'center',
|
||||||
|
height: 96,
|
||||||
|
userSelect: 'none',
|
||||||
|
bg: 'base.200',
|
||||||
|
_dark: {
|
||||||
|
bg: 'base.900',
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Text variant="subtext">No Models Found</Text>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
flexDirection: 'column',
|
||||||
|
gap: 2,
|
||||||
|
w: '100%',
|
||||||
|
minW: '50%',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<IAIInput
|
||||||
|
onChange={handleSearchFilter}
|
||||||
|
label={t('modelManager.search')}
|
||||||
|
labelPos="side"
|
||||||
|
/>
|
||||||
|
<Flex p={2} gap={2}>
|
||||||
|
<Text sx={{ fontWeight: 600 }}>
|
||||||
|
Models Found: {foundModels.length}
|
||||||
|
</Text>
|
||||||
|
<Text
|
||||||
|
sx={{
|
||||||
|
fontWeight: 600,
|
||||||
|
color: 'accent.500',
|
||||||
|
_dark: {
|
||||||
|
color: 'accent.200',
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
Not Installed: {filteredModels.length}
|
||||||
|
</Text>
|
||||||
|
</Flex>
|
||||||
|
|
||||||
|
<IAIScrollArea offsetScrollbars>
|
||||||
|
<Flex gap={2} flexDirection="column">
|
||||||
|
{renderModels({ models: filteredModels })}
|
||||||
|
{renderModels({ models: alreadyInstalled, showActions: false })}
|
||||||
|
</Flex>
|
||||||
|
</IAIScrollArea>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
return renderFoundModels();
|
||||||
|
}
|
||||||
|
|
||||||
|
const foundModelsFilter = (
|
||||||
|
data: SearchFolderResponse | undefined,
|
||||||
|
nameFilter: string
|
||||||
|
) => {
|
||||||
|
const filteredModels: SearchFolderResponse = [];
|
||||||
|
forEach(data, (model) => {
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (model.includes(nameFilter)) {
|
||||||
|
filteredModels.push(model);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return filteredModels;
|
||||||
|
};
|
@ -0,0 +1,97 @@
|
|||||||
|
import { Box, Flex, Text } from '@chakra-ui/react';
|
||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
|
import { motion } from 'framer-motion';
|
||||||
|
import { useEffect, useState } from 'react';
|
||||||
|
import { FaTimes } from 'react-icons/fa';
|
||||||
|
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
|
||||||
|
import AdvancedAddCheckpoint from './AdvancedAddCheckpoint';
|
||||||
|
import AdvancedAddDiffusers from './AdvancedAddDiffusers';
|
||||||
|
import { ManualAddMode, advancedAddModeData } from './AdvancedAddModels';
|
||||||
|
|
||||||
|
export default function ScanAdvancedAddModels() {
|
||||||
|
const advancedAddScanModel = useAppSelector(
|
||||||
|
(state: RootState) => state.modelmanager.advancedAddScanModel
|
||||||
|
);
|
||||||
|
|
||||||
|
const [advancedAddMode, setAdvancedAddMode] =
|
||||||
|
useState<ManualAddMode>('diffusers');
|
||||||
|
|
||||||
|
const [isCheckpoint, setIsCheckpoint] = useState<boolean>(true);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
advancedAddScanModel &&
|
||||||
|
['.ckpt', '.safetensors', '.pth', '.pt'].some((ext) =>
|
||||||
|
advancedAddScanModel.endsWith(ext)
|
||||||
|
)
|
||||||
|
? setAdvancedAddMode('checkpoint')
|
||||||
|
: setAdvancedAddMode('diffusers');
|
||||||
|
}, [advancedAddScanModel, setAdvancedAddMode, isCheckpoint]);
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
return (
|
||||||
|
advancedAddScanModel && (
|
||||||
|
<Box
|
||||||
|
as={motion.div}
|
||||||
|
initial={{ x: -100, opacity: 0 }}
|
||||||
|
animate={{ x: 0, opacity: 1, transition: { duration: 0.2 } }}
|
||||||
|
sx={{
|
||||||
|
display: 'flex',
|
||||||
|
flexDirection: 'column',
|
||||||
|
minWidth: '40%',
|
||||||
|
maxHeight: window.innerHeight - 300,
|
||||||
|
overflow: 'scroll',
|
||||||
|
p: 4,
|
||||||
|
gap: 4,
|
||||||
|
borderRadius: 4,
|
||||||
|
bg: 'base.200',
|
||||||
|
_dark: {
|
||||||
|
bg: 'base.800',
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Flex justifyContent="space-between" alignItems="center">
|
||||||
|
<Text size="xl" fontWeight={600}>
|
||||||
|
{isCheckpoint || advancedAddMode === 'checkpoint'
|
||||||
|
? 'Add Checkpoint Model'
|
||||||
|
: 'Add Diffusers Model'}
|
||||||
|
</Text>
|
||||||
|
<IAIIconButton
|
||||||
|
icon={<FaTimes />}
|
||||||
|
aria-label="Close Advanced"
|
||||||
|
onClick={() => dispatch(setAdvancedAddScanModel(null))}
|
||||||
|
size="sm"
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
<IAIMantineSelect
|
||||||
|
label="Model Type"
|
||||||
|
value={advancedAddMode}
|
||||||
|
data={advancedAddModeData}
|
||||||
|
onChange={(v) => {
|
||||||
|
if (!v) return;
|
||||||
|
setAdvancedAddMode(v as ManualAddMode);
|
||||||
|
if (v === 'checkpoint') {
|
||||||
|
setIsCheckpoint(true);
|
||||||
|
} else {
|
||||||
|
setIsCheckpoint(false);
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
{isCheckpoint ? (
|
||||||
|
<AdvancedAddCheckpoint
|
||||||
|
key={advancedAddScanModel}
|
||||||
|
model_path={advancedAddScanModel}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<AdvancedAddDiffusers
|
||||||
|
key={advancedAddScanModel}
|
||||||
|
model_path={advancedAddScanModel}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</Box>
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
@ -0,0 +1,25 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
import FoundModelsList from './FoundModelsList';
|
||||||
|
import ScanAdvancedAddModels from './ScanAdvancedAddModels';
|
||||||
|
import SearchFolderForm from './SearchFolderForm';
|
||||||
|
|
||||||
|
export default function ScanModels() {
|
||||||
|
return (
|
||||||
|
<Flex flexDirection="column" w="100%" gap={4}>
|
||||||
|
<SearchFolderForm />
|
||||||
|
<Flex gap={4}>
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
maxHeight: window.innerHeight - 300,
|
||||||
|
overflow: 'scroll',
|
||||||
|
gap: 4,
|
||||||
|
w: '100%',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<FoundModelsList />
|
||||||
|
</Flex>
|
||||||
|
<ScanAdvancedAddModels />
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}
|
@ -0,0 +1,139 @@
|
|||||||
|
import { Flex, Text } from '@chakra-ui/react';
|
||||||
|
import { useForm } from '@mantine/form';
|
||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
|
import IAIInput from 'common/components/IAIInput';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { FaSearch, FaSync, FaTrash } from 'react-icons/fa';
|
||||||
|
import { useGetModelsInFolderQuery } from 'services/api/endpoints/models';
|
||||||
|
import {
|
||||||
|
setAdvancedAddScanModel,
|
||||||
|
setSearchFolder,
|
||||||
|
} from '../../store/modelManagerSlice';
|
||||||
|
|
||||||
|
type SearchFolderForm = {
|
||||||
|
folder: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
function SearchFolderForm() {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const searchFolder = useAppSelector(
|
||||||
|
(state: RootState) => state.modelmanager.searchFolder
|
||||||
|
);
|
||||||
|
|
||||||
|
const { refetch: refetchFoundModels } = useGetModelsInFolderQuery({
|
||||||
|
search_path: searchFolder ? searchFolder : '',
|
||||||
|
});
|
||||||
|
|
||||||
|
const searchFolderForm = useForm<SearchFolderForm>({
|
||||||
|
initialValues: {
|
||||||
|
folder: '',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const searchFolderFormSubmitHandler = useCallback(
|
||||||
|
(values: SearchFolderForm) => {
|
||||||
|
dispatch(setSearchFolder(values.folder));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const scanAgainHandler = () => {
|
||||||
|
refetchFoundModels();
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<form
|
||||||
|
onSubmit={searchFolderForm.onSubmit((values) =>
|
||||||
|
searchFolderFormSubmitHandler(values)
|
||||||
|
)}
|
||||||
|
style={{ width: '100%' }}
|
||||||
|
>
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
w: '100%',
|
||||||
|
gap: 2,
|
||||||
|
borderRadius: 4,
|
||||||
|
alignItems: 'center',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Flex w="100%" alignItems="center" gap={4} minH={12}>
|
||||||
|
<Text
|
||||||
|
sx={{
|
||||||
|
fontSize: 'sm',
|
||||||
|
fontWeight: 600,
|
||||||
|
color: 'base.700',
|
||||||
|
minW: 'max-content',
|
||||||
|
_dark: { color: 'base.300' },
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
Folder
|
||||||
|
</Text>
|
||||||
|
{!searchFolder ? (
|
||||||
|
<IAIInput
|
||||||
|
w="100%"
|
||||||
|
size="md"
|
||||||
|
{...searchFolderForm.getInputProps('folder')}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
w: '100%',
|
||||||
|
p: 2,
|
||||||
|
px: 4,
|
||||||
|
bg: 'base.300',
|
||||||
|
borderRadius: 4,
|
||||||
|
fontSize: 'sm',
|
||||||
|
fontWeight: 'bold',
|
||||||
|
_dark: { bg: 'base.700' },
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{searchFolder}
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
|
||||||
|
<Flex gap={2}>
|
||||||
|
{!searchFolder ? (
|
||||||
|
<IAIIconButton
|
||||||
|
aria-label={t('modelManager.findModels')}
|
||||||
|
tooltip={t('modelManager.findModels')}
|
||||||
|
icon={<FaSearch />}
|
||||||
|
fontSize={18}
|
||||||
|
size="sm"
|
||||||
|
type="submit"
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<IAIIconButton
|
||||||
|
aria-label={t('modelManager.scanAgain')}
|
||||||
|
tooltip={t('modelManager.scanAgain')}
|
||||||
|
icon={<FaSync />}
|
||||||
|
onClick={scanAgainHandler}
|
||||||
|
fontSize={18}
|
||||||
|
size="sm"
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<IAIIconButton
|
||||||
|
aria-label={t('modelManager.clearCheckpointFolder')}
|
||||||
|
tooltip={t('modelManager.clearCheckpointFolder')}
|
||||||
|
icon={<FaTrash />}
|
||||||
|
size="sm"
|
||||||
|
onClick={() => {
|
||||||
|
dispatch(setSearchFolder(null));
|
||||||
|
dispatch(setAdvancedAddScanModel(null));
|
||||||
|
}}
|
||||||
|
isDisabled={!searchFolder}
|
||||||
|
colorScheme="red"
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
</form>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export default memo(SearchFolderForm);
|
@ -0,0 +1,108 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
// import { addNewModel } from 'app/socketio/actions';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
import { SelectItem } from '@mantine/core';
|
||||||
|
import { useForm } from '@mantine/form';
|
||||||
|
import { makeToast } from 'app/components/Toaster';
|
||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
import IAIMantineTextInput from 'common/components/IAIMantineInput';
|
||||||
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { useImportMainModelsMutation } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
|
const predictionSelectData: SelectItem[] = [
|
||||||
|
{ label: 'None', value: 'none' },
|
||||||
|
{ label: 'v_prediction', value: 'v_prediction' },
|
||||||
|
{ label: 'epsilon', value: 'epsilon' },
|
||||||
|
{ label: 'sample', value: 'sample' },
|
||||||
|
];
|
||||||
|
|
||||||
|
type ExtendedImportModelConfig = {
|
||||||
|
location: string;
|
||||||
|
prediction_type?: 'v_prediction' | 'epsilon' | 'sample' | 'none' | undefined;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default function SimpleAddModels() {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const isProcessing = useAppSelector(
|
||||||
|
(state: RootState) => state.system.isProcessing
|
||||||
|
);
|
||||||
|
|
||||||
|
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
|
||||||
|
|
||||||
|
const addModelForm = useForm<ExtendedImportModelConfig>({
|
||||||
|
initialValues: {
|
||||||
|
location: '',
|
||||||
|
prediction_type: undefined,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const handleAddModelSubmit = (values: ExtendedImportModelConfig) => {
|
||||||
|
const importModelResponseBody = {
|
||||||
|
location: values.location,
|
||||||
|
prediction_type:
|
||||||
|
values.prediction_type === 'none' ? undefined : values.prediction_type,
|
||||||
|
};
|
||||||
|
|
||||||
|
importMainModel({ body: importModelResponseBody })
|
||||||
|
.unwrap()
|
||||||
|
.then((_) => {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: 'Model Added',
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
addModelForm.reset();
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
if (error) {
|
||||||
|
console.log(error);
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: `${error.data.detail} `,
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<form
|
||||||
|
onSubmit={addModelForm.onSubmit((v) => handleAddModelSubmit(v))}
|
||||||
|
style={{ width: '100%' }}
|
||||||
|
>
|
||||||
|
<Flex flexDirection="column" width="100%" gap={4}>
|
||||||
|
<IAIMantineTextInput
|
||||||
|
label="Model Location"
|
||||||
|
placeholder="Provide a path to a local Diffusers model, local checkpoint / safetensors model or a HuggingFace Repo ID"
|
||||||
|
w="100%"
|
||||||
|
{...addModelForm.getInputProps('location')}
|
||||||
|
/>
|
||||||
|
<IAIMantineSelect
|
||||||
|
label="Prediction Type (for Stable Diffusion 2.x Models only)"
|
||||||
|
data={predictionSelectData}
|
||||||
|
defaultValue="none"
|
||||||
|
{...addModelForm.getInputProps('prediction_type')}
|
||||||
|
/>
|
||||||
|
<IAIButton
|
||||||
|
type="submit"
|
||||||
|
isLoading={isLoading}
|
||||||
|
isDisabled={isLoading || isProcessing}
|
||||||
|
>
|
||||||
|
{t('modelManager.addModel')}
|
||||||
|
</IAIButton>
|
||||||
|
</Flex>
|
||||||
|
</form>
|
||||||
|
);
|
||||||
|
}
|
@ -0,0 +1,39 @@
|
|||||||
|
import { ButtonGroup, Flex } from '@chakra-ui/react';
|
||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
import { useState } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import AddModels from './AddModelsPanel/AddModels';
|
||||||
|
import ScanModels from './AddModelsPanel/ScanModels';
|
||||||
|
|
||||||
|
type AddModelTabs = 'add' | 'scan';
|
||||||
|
|
||||||
|
export default function ImportModelsPanel() {
|
||||||
|
const [addModelTab, setAddModelTab] = useState<AddModelTabs>('add');
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex flexDirection="column" gap={4}>
|
||||||
|
<ButtonGroup isAttached>
|
||||||
|
<IAIButton
|
||||||
|
onClick={() => setAddModelTab('add')}
|
||||||
|
isChecked={addModelTab == 'add'}
|
||||||
|
size="sm"
|
||||||
|
width="100%"
|
||||||
|
>
|
||||||
|
{t('modelManager.addModel')}
|
||||||
|
</IAIButton>
|
||||||
|
<IAIButton
|
||||||
|
onClick={() => setAddModelTab('scan')}
|
||||||
|
isChecked={addModelTab == 'scan'}
|
||||||
|
size="sm"
|
||||||
|
width="100%"
|
||||||
|
>
|
||||||
|
{t('modelManager.scanForModels')}
|
||||||
|
</IAIButton>
|
||||||
|
</ButtonGroup>
|
||||||
|
|
||||||
|
{addModelTab == 'add' && <AddModels />}
|
||||||
|
{addModelTab == 'scan' && <ScanModels />}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}
|
@ -1,11 +1,4 @@
|
|||||||
import {
|
import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react';
|
||||||
Flex,
|
|
||||||
Radio,
|
|
||||||
RadioGroup,
|
|
||||||
Text,
|
|
||||||
Tooltip,
|
|
||||||
useColorMode,
|
|
||||||
} from '@chakra-ui/react';
|
|
||||||
import { makeToast } from 'app/components/Toaster';
|
import { makeToast } from 'app/components/Toaster';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
@ -23,7 +16,6 @@ import {
|
|||||||
useMergeMainModelsMutation,
|
useMergeMainModelsMutation,
|
||||||
} from 'services/api/endpoints/models';
|
} from 'services/api/endpoints/models';
|
||||||
import { BaseModelType, MergeModelConfig } from 'services/api/types';
|
import { BaseModelType, MergeModelConfig } from 'services/api/types';
|
||||||
import { mode } from 'theme/util/mode';
|
|
||||||
|
|
||||||
const baseModelTypeSelectData = [
|
const baseModelTypeSelectData = [
|
||||||
{ label: 'Stable Diffusion 1', value: 'sd-1' },
|
{ label: 'Stable Diffusion 1', value: 'sd-1' },
|
||||||
@ -38,13 +30,9 @@ type MergeInterpolationMethods =
|
|||||||
|
|
||||||
export default function MergeModelsPanel() {
|
export default function MergeModelsPanel() {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { colorMode } = useColorMode();
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const { data } = useGetMainModelsQuery({
|
const { data } = useGetMainModelsQuery();
|
||||||
model_type: 'main',
|
|
||||||
base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
|
|
||||||
});
|
|
||||||
|
|
||||||
const [mergeModels, { isLoading }] = useMergeMainModelsMutation();
|
const [mergeModels, { isLoading }] = useMergeMainModelsMutation();
|
||||||
|
|
||||||
@ -70,10 +58,10 @@ export default function MergeModelsPanel() {
|
|||||||
}, [sd1DiffusersModels, sd2DiffusersModels]);
|
}, [sd1DiffusersModels, sd2DiffusersModels]);
|
||||||
|
|
||||||
const [modelOne, setModelOne] = useState<string | null>(
|
const [modelOne, setModelOne] = useState<string | null>(
|
||||||
Object.keys(modelsMap[baseModel])[0]
|
Object.keys(modelsMap[baseModel as keyof typeof modelsMap])[0]
|
||||||
);
|
);
|
||||||
const [modelTwo, setModelTwo] = useState<string | null>(
|
const [modelTwo, setModelTwo] = useState<string | null>(
|
||||||
Object.keys(modelsMap[baseModel])[1]
|
Object.keys(modelsMap[baseModel as keyof typeof modelsMap])[1]
|
||||||
);
|
);
|
||||||
|
|
||||||
const [modelThree, setModelThree] = useState<string | null>(null);
|
const [modelThree, setModelThree] = useState<string | null>(null);
|
||||||
@ -101,9 +89,9 @@ export default function MergeModelsPanel() {
|
|||||||
modelsMap[baseModel as keyof typeof modelsMap]
|
modelsMap[baseModel as keyof typeof modelsMap]
|
||||||
).filter((model) => model !== modelOne && model !== modelThree);
|
).filter((model) => model !== modelOne && model !== modelThree);
|
||||||
|
|
||||||
const modelThreeList = Object.keys(modelsMap[baseModel]).filter(
|
const modelThreeList = Object.keys(
|
||||||
(model) => model !== modelOne && model !== modelTwo
|
modelsMap[baseModel as keyof typeof modelsMap]
|
||||||
);
|
).filter((model) => model !== modelOne && model !== modelTwo);
|
||||||
|
|
||||||
const handleBaseModelChange = (v: string) => {
|
const handleBaseModelChange = (v: string) => {
|
||||||
setBaseModel(v as BaseModelType);
|
setBaseModel(v as BaseModelType);
|
||||||
@ -128,9 +116,9 @@ export default function MergeModelsPanel() {
|
|||||||
mergedModelName !== '' ? mergedModelName : models_names.join('-'),
|
mergedModelName !== '' ? mergedModelName : models_names.join('-'),
|
||||||
alpha: modelMergeAlpha,
|
alpha: modelMergeAlpha,
|
||||||
interp: modelMergeInterp,
|
interp: modelMergeInterp,
|
||||||
// model_merge_save_path:
|
|
||||||
// modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
|
|
||||||
force: modelMergeForce,
|
force: modelMergeForce,
|
||||||
|
merge_dest_directory:
|
||||||
|
modelMergeSaveLocType === 'root' ? undefined : modelMergeCustomSaveLoc,
|
||||||
};
|
};
|
||||||
|
|
||||||
mergeModels({
|
mergeModels({
|
||||||
@ -230,7 +218,10 @@ export default function MergeModelsPanel() {
|
|||||||
padding: 4,
|
padding: 4,
|
||||||
borderRadius: 'base',
|
borderRadius: 'base',
|
||||||
gap: 4,
|
gap: 4,
|
||||||
bg: mode('base.100', 'base.800')(colorMode),
|
bg: 'base.200',
|
||||||
|
_dark: {
|
||||||
|
bg: 'base.800',
|
||||||
|
},
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
@ -255,7 +246,10 @@ export default function MergeModelsPanel() {
|
|||||||
padding: 4,
|
padding: 4,
|
||||||
borderRadius: 'base',
|
borderRadius: 'base',
|
||||||
gap: 4,
|
gap: 4,
|
||||||
bg: mode('base.100', 'base.800')(colorMode),
|
bg: 'base.200',
|
||||||
|
_dark: {
|
||||||
|
bg: 'base.800',
|
||||||
|
},
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Text fontWeight={500} fontSize="sm" variant="subtext">
|
<Text fontWeight={500} fontSize="sm" variant="subtext">
|
||||||
@ -291,13 +285,16 @@ export default function MergeModelsPanel() {
|
|||||||
</RadioGroup>
|
</RadioGroup>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|
||||||
{/* <Flex
|
<Flex
|
||||||
sx={{
|
sx={{
|
||||||
flexDirection: 'column',
|
flexDirection: 'column',
|
||||||
padding: 4,
|
padding: 4,
|
||||||
borderRadius: 'base',
|
borderRadius: 'base',
|
||||||
gap: 4,
|
gap: 4,
|
||||||
bg: 'base.900',
|
bg: 'base.200',
|
||||||
|
_dark: {
|
||||||
|
bg: 'base.900',
|
||||||
|
},
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Flex columnGap={4}>
|
<Flex columnGap={4}>
|
||||||
@ -327,7 +324,7 @@ export default function MergeModelsPanel() {
|
|||||||
onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
|
onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
</Flex> */}
|
</Flex>
|
||||||
|
|
||||||
<IAISimpleCheckbox
|
<IAISimpleCheckbox
|
||||||
label={t('modelManager.ignoreMismatch')}
|
label={t('modelManager.ignoreMismatch')}
|
||||||
|
@ -1,33 +1,26 @@
|
|||||||
import { Divider, Flex, Text } from '@chakra-ui/react';
|
import { Badge, Divider, Flex, Text } from '@chakra-ui/react';
|
||||||
import { useForm } from '@mantine/form';
|
import { useForm } from '@mantine/form';
|
||||||
import { makeToast } from 'app/components/Toaster';
|
import { makeToast } from 'app/components/Toaster';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
import IAIMantineTextInput from 'common/components/IAIMantineInput';
|
import IAIMantineTextInput from 'common/components/IAIMantineInput';
|
||||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { useCallback } from 'react';
|
import { useCallback, useEffect, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import {
|
import {
|
||||||
CheckpointModelConfigEntity,
|
CheckpointModelConfigEntity,
|
||||||
|
useGetCheckpointConfigsQuery,
|
||||||
useUpdateMainModelsMutation,
|
useUpdateMainModelsMutation,
|
||||||
} from 'services/api/endpoints/models';
|
} from 'services/api/endpoints/models';
|
||||||
import { CheckpointModelConfig } from 'services/api/types';
|
import { CheckpointModelConfig } from 'services/api/types';
|
||||||
|
import BaseModelSelect from '../shared/BaseModelSelect';
|
||||||
|
import CheckpointConfigsSelect from '../shared/CheckpointConfigsSelect';
|
||||||
|
import ModelVariantSelect from '../shared/ModelVariantSelect';
|
||||||
import ModelConvert from './ModelConvert';
|
import ModelConvert from './ModelConvert';
|
||||||
|
|
||||||
const baseModelSelectData = [
|
|
||||||
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
|
|
||||||
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
|
|
||||||
];
|
|
||||||
|
|
||||||
const variantSelectData = [
|
|
||||||
{ value: 'normal', label: 'Normal' },
|
|
||||||
{ value: 'inpaint', label: 'Inpaint' },
|
|
||||||
{ value: 'depth', label: 'Depth' },
|
|
||||||
];
|
|
||||||
|
|
||||||
type CheckpointModelEditProps = {
|
type CheckpointModelEditProps = {
|
||||||
model: CheckpointModelConfigEntity;
|
model: CheckpointModelConfigEntity;
|
||||||
};
|
};
|
||||||
@ -38,6 +31,15 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
|
|||||||
const { model } = props;
|
const { model } = props;
|
||||||
|
|
||||||
const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation();
|
const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation();
|
||||||
|
const { data: availableCheckpointConfigs } = useGetCheckpointConfigsQuery();
|
||||||
|
|
||||||
|
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!availableCheckpointConfigs?.includes(model.config)) {
|
||||||
|
setUseCustomConfig(true);
|
||||||
|
}
|
||||||
|
}, [availableCheckpointConfigs, model.config]);
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
@ -80,7 +82,7 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
|
|||||||
)
|
)
|
||||||
);
|
);
|
||||||
})
|
})
|
||||||
.catch((error) => {
|
.catch((_) => {
|
||||||
checkpointEditForm.reset();
|
checkpointEditForm.reset();
|
||||||
dispatch(
|
dispatch(
|
||||||
addToast(
|
addToast(
|
||||||
@ -113,7 +115,20 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
|
|||||||
{MODEL_TYPE_MAP[model.base_model]} Model
|
{MODEL_TYPE_MAP[model.base_model]} Model
|
||||||
</Text>
|
</Text>
|
||||||
</Flex>
|
</Flex>
|
||||||
<ModelConvert model={model} />
|
{!['sdxl', 'sdxl-refiner'].includes(model.base_model) ? (
|
||||||
|
<ModelConvert model={model} />
|
||||||
|
) : (
|
||||||
|
<Badge
|
||||||
|
sx={{
|
||||||
|
p: 2,
|
||||||
|
borderRadius: 4,
|
||||||
|
bg: 'error.200',
|
||||||
|
_dark: { bg: 'error.400' },
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
Conversion Not Supported
|
||||||
|
</Badge>
|
||||||
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
<Divider />
|
<Divider />
|
||||||
|
|
||||||
@ -128,21 +143,24 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
|
|||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
<Flex flexDirection="column" overflowY="scroll" gap={4}>
|
<Flex flexDirection="column" overflowY="scroll" gap={4}>
|
||||||
|
<IAIMantineTextInput
|
||||||
|
label={t('modelManager.name')}
|
||||||
|
{...checkpointEditForm.getInputProps('model_name')}
|
||||||
|
/>
|
||||||
<IAIMantineTextInput
|
<IAIMantineTextInput
|
||||||
label={t('modelManager.description')}
|
label={t('modelManager.description')}
|
||||||
{...checkpointEditForm.getInputProps('description')}
|
{...checkpointEditForm.getInputProps('description')}
|
||||||
/>
|
/>
|
||||||
<IAIMantineSelect
|
<BaseModelSelect
|
||||||
label={t('modelManager.baseModel')}
|
required
|
||||||
data={baseModelSelectData}
|
|
||||||
{...checkpointEditForm.getInputProps('base_model')}
|
{...checkpointEditForm.getInputProps('base_model')}
|
||||||
/>
|
/>
|
||||||
<IAIMantineSelect
|
<ModelVariantSelect
|
||||||
label={t('modelManager.variant')}
|
required
|
||||||
data={variantSelectData}
|
|
||||||
{...checkpointEditForm.getInputProps('variant')}
|
{...checkpointEditForm.getInputProps('variant')}
|
||||||
/>
|
/>
|
||||||
<IAIMantineTextInput
|
<IAIMantineTextInput
|
||||||
|
required
|
||||||
label={t('modelManager.modelLocation')}
|
label={t('modelManager.modelLocation')}
|
||||||
{...checkpointEditForm.getInputProps('path')}
|
{...checkpointEditForm.getInputProps('path')}
|
||||||
/>
|
/>
|
||||||
@ -150,10 +168,27 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
|
|||||||
label={t('modelManager.vaeLocation')}
|
label={t('modelManager.vaeLocation')}
|
||||||
{...checkpointEditForm.getInputProps('vae')}
|
{...checkpointEditForm.getInputProps('vae')}
|
||||||
/>
|
/>
|
||||||
<IAIMantineTextInput
|
|
||||||
label={t('modelManager.config')}
|
<Flex flexDirection="column" gap={2}>
|
||||||
{...checkpointEditForm.getInputProps('config')}
|
{!useCustomConfig ? (
|
||||||
/>
|
<CheckpointConfigsSelect
|
||||||
|
required
|
||||||
|
{...checkpointEditForm.getInputProps('config')}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<IAIMantineTextInput
|
||||||
|
required
|
||||||
|
label={t('modelManager.config')}
|
||||||
|
{...checkpointEditForm.getInputProps('config')}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
<IAISimpleCheckbox
|
||||||
|
isChecked={useCustomConfig}
|
||||||
|
onChange={() => setUseCustomConfig(!useCustomConfig)}
|
||||||
|
label="Use Custom Config"
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
|
||||||
<IAIButton
|
<IAIButton
|
||||||
type="submit"
|
type="submit"
|
||||||
isDisabled={isBusy || isLoading}
|
isDisabled={isBusy || isLoading}
|
||||||
|
@ -4,7 +4,6 @@ import { makeToast } from 'app/components/Toaster';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
import IAIMantineTextInput from 'common/components/IAIMantineInput';
|
import IAIMantineTextInput from 'common/components/IAIMantineInput';
|
||||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
|
||||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
@ -15,22 +14,13 @@ import {
|
|||||||
useUpdateMainModelsMutation,
|
useUpdateMainModelsMutation,
|
||||||
} from 'services/api/endpoints/models';
|
} from 'services/api/endpoints/models';
|
||||||
import { DiffusersModelConfig } from 'services/api/types';
|
import { DiffusersModelConfig } from 'services/api/types';
|
||||||
|
import BaseModelSelect from '../shared/BaseModelSelect';
|
||||||
|
import ModelVariantSelect from '../shared/ModelVariantSelect';
|
||||||
|
|
||||||
type DiffusersModelEditProps = {
|
type DiffusersModelEditProps = {
|
||||||
model: DiffusersModelConfigEntity;
|
model: DiffusersModelConfigEntity;
|
||||||
};
|
};
|
||||||
|
|
||||||
const baseModelSelectData = [
|
|
||||||
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
|
|
||||||
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
|
|
||||||
];
|
|
||||||
|
|
||||||
const variantSelectData = [
|
|
||||||
{ value: 'normal', label: 'Normal' },
|
|
||||||
{ value: 'inpaint', label: 'Inpaint' },
|
|
||||||
{ value: 'depth', label: 'Depth' },
|
|
||||||
];
|
|
||||||
|
|
||||||
export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
|
export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
|
||||||
const isBusy = useAppSelector(selectIsBusy);
|
const isBusy = useAppSelector(selectIsBusy);
|
||||||
|
|
||||||
@ -65,6 +55,7 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
|
|||||||
model_name: model.model_name,
|
model_name: model.model_name,
|
||||||
body: values,
|
body: values,
|
||||||
};
|
};
|
||||||
|
|
||||||
updateMainModel(responseBody)
|
updateMainModel(responseBody)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.then((payload) => {
|
.then((payload) => {
|
||||||
@ -78,7 +69,7 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
|
|||||||
)
|
)
|
||||||
);
|
);
|
||||||
})
|
})
|
||||||
.catch((error) => {
|
.catch((_) => {
|
||||||
diffusersEditForm.reset();
|
diffusersEditForm.reset();
|
||||||
dispatch(
|
dispatch(
|
||||||
addToast(
|
addToast(
|
||||||
@ -118,21 +109,24 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
|
|||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
<Flex flexDirection="column" overflowY="scroll" gap={4}>
|
<Flex flexDirection="column" overflowY="scroll" gap={4}>
|
||||||
|
<IAIMantineTextInput
|
||||||
|
label={t('modelManager.name')}
|
||||||
|
{...diffusersEditForm.getInputProps('model_name')}
|
||||||
|
/>
|
||||||
<IAIMantineTextInput
|
<IAIMantineTextInput
|
||||||
label={t('modelManager.description')}
|
label={t('modelManager.description')}
|
||||||
{...diffusersEditForm.getInputProps('description')}
|
{...diffusersEditForm.getInputProps('description')}
|
||||||
/>
|
/>
|
||||||
<IAIMantineSelect
|
<BaseModelSelect
|
||||||
label={t('modelManager.baseModel')}
|
required
|
||||||
data={baseModelSelectData}
|
|
||||||
{...diffusersEditForm.getInputProps('base_model')}
|
{...diffusersEditForm.getInputProps('base_model')}
|
||||||
/>
|
/>
|
||||||
<IAIMantineSelect
|
<ModelVariantSelect
|
||||||
label={t('modelManager.variant')}
|
required
|
||||||
data={variantSelectData}
|
|
||||||
{...diffusersEditForm.getInputProps('variant')}
|
{...diffusersEditForm.getInputProps('variant')}
|
||||||
/>
|
/>
|
||||||
<IAIMantineTextInput
|
<IAIMantineTextInput
|
||||||
|
required
|
||||||
label={t('modelManager.modelLocation')}
|
label={t('modelManager.modelLocation')}
|
||||||
{...diffusersEditForm.getInputProps('path')}
|
{...diffusersEditForm.getInputProps('path')}
|
||||||
/>
|
/>
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user