Merge branch 'main' into sdxl-support

This commit is contained in:
blessedcoolant 2023-07-18 13:34:07 +12:00
commit 13da881953
100 changed files with 3322 additions and 10176 deletions

View File

@@ -132,8 +132,10 @@ and go to http://localhost:9090.
 ### Command-Line Installation (for developers and users familiar with Terminals)

-You must have Python 3.9 or 3.10 installed on your machine. Earlier or later versions are
-not supported.
+You must have Python 3.9 or 3.10 installed on your machine. Earlier or
+later versions are not supported.
+Node.js also needs to be installed along with yarn (can be installed with
+the command `npm install -g yarn` if needed)

 1. Open a command-line window on your machine. The PowerShell is recommended for Windows.
 2. Create a directory to install InvokeAI into. You'll need at least 15 GB of free space:
@@ -197,11 +199,18 @@ not supported.
 7. Launch the web server (do it every time you run InvokeAI):

    ```terminal
-   invokeai --web
+   invokeai-web
    ```

-8. Point your browser to http://localhost:9090 to bring up the web interface.
-9. Type `banana sushi` in the box on the top left and click `Invoke`.
+8. Build Node.js assets
+
+   ```terminal
+   cd invokeai/frontend/web/
+   yarn vite build
+   ```
+
+9. Point your browser to http://localhost:9090 to bring up the web interface.
+10. Type `banana sushi` in the box on the top left and click `Invoke`.

 Be sure to activate the virtual environment each time before re-launching InvokeAI,
 using `source .venv/bin/activate` or `.venv\Scripts\activate`.

View File

@@ -81,3 +81,193 @@ pytest --cov; open ./coverage/html/index.html
 <!--#TODO: get input from blessedcoolant here, for the moment inserted the frontend README via snippets extension.-->
 --8<-- "invokeai/frontend/web/README.md"
## Developing InvokeAI in VSCode
VSCode offers some nice tools:
- python debugger
- automatic `venv` activation
- remote dev (e.g. run InvokeAI on a beefy linux desktop while you type in
comfort on your macbook)
### Setup
You'll need the
[Python](https://marketplace.visualstudio.com/items?itemName=ms-python.python)
and
[Pylance](https://marketplace.visualstudio.com/items?itemName=ms-python.vscode-pylance)
extensions installed first.
It's also really handy to install the `Jupyter` extensions:
- [Jupyter](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.jupyter)
- [Jupyter Cell Tags](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.vscode-jupyter-cell-tags)
- [Jupyter Notebook Renderers](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.jupyter-renderers)
- [Jupyter Slide Show](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.vscode-jupyter-slideshow)
#### InvokeAI workspace
Creating a VSCode workspace for working on InvokeAI is highly recommended. It
can hold InvokeAI-specific settings and configs.
To make a workspace:
- Open the InvokeAI repo dir in VSCode
- `File` > `Save Workspace As` > save it _outside_ the repo
#### Default python interpreter (i.e. automatic virtual environment activation)
- Use command palette to run command
`Preferences: Open Workspace Settings (JSON)`
- Add `python.defaultInterpreterPath` to `settings`, pointing to your `venv`'s
python
Should look something like this:
```json
{
// I like to have all InvokeAI-related folders in my workspace
"folders": [
{
// repo root
"path": "InvokeAI"
},
{
// InvokeAI root dir, where `invokeai.yaml` lives
"path": "/path/to/invokeai_root"
}
],
"settings": {
// Where your InvokeAI `venv`'s python executable lives
"python.defaultInterpreterPath": "/path/to/invokeai_root/.venv/bin/python"
}
}
```
Now when you open the VSCode integrated terminal, or do anything that needs to
run python, it will automatically be in your InvokeAI virtual environment.
Bonus: When you create a Jupyter notebook, when you run it, you'll be prompted
for the python interpreter to run in. This will default to your `venv` python,
and so you'll have access to the same python environment as the InvokeAI app.
This is _super_ handy.
#### Debugging configs with `launch.json`
Debugging configs are managed in a `launch.json` file. Like most VSCode configs,
these can be scoped to a workspace or folder.
Follow the [official guide](https://code.visualstudio.com/docs/python/debugging)
to set up your `launch.json` and try it out.
Now we can create the InvokeAI debugging configs:
```json
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
// Run the InvokeAI backend & serve the pre-built UI
"name": "InvokeAI Web",
"type": "python",
"request": "launch",
"program": "scripts/invokeai-web.py",
"args": [
// Your InvokeAI root dir (where `invokeai.yaml` lives)
"--root",
"/path/to/invokeai_root",
// Access the app from anywhere on your local network
"--host",
"0.0.0.0"
],
"justMyCode": true
},
{
// Run the nodes-based CLI
"name": "InvokeAI CLI",
"type": "python",
"request": "launch",
"program": "scripts/invokeai-cli.py",
"justMyCode": true
},
{
// Run tests
"name": "InvokeAI Test",
"type": "python",
"request": "launch",
"module": "pytest",
"args": ["--capture=no"],
"justMyCode": true
},
{
// Run a single test
"name": "InvokeAI Single Test",
"type": "python",
"request": "launch",
"module": "pytest",
"args": [
// Change this to point to the specific test you are working on
"tests/nodes/test_invoker.py"
],
"justMyCode": true
},
{
// This is the default, useful to just run a single file
"name": "Python: File",
"type": "python",
"request": "launch",
"program": "${file}",
"justMyCode": true
}
]
}
```
You'll see these configs in the debugging configs dropdown. Running them will
start InvokeAI with the debugger attached, in the correct environment, and the
app will work just like it normally does.
Enjoy debugging InvokeAI with ease (not that we have any bugs, of course).
#### Remote dev
This is very easy to set up and provides the same very smooth experience as
local development. Environments and debugging, as set up above, just work,
though you'd need to recreate the workspace and debugging configs on the remote.
Consult the
[official guide](https://code.visualstudio.com/docs/remote/remote-overview) to
get it set up.
We suggest using VSCode's included settings sync so that your remote dev host
has all the same app settings and extensions automagically.
##### One remote dev gotcha
I've found the automatic port forwarding to be very flaky. You can disable it
in `Preferences: Open Remote Settings (ssh: hostname)`. Search for
`remote.autoForwardPorts` and untick the box.
To forward ports very reliably, use SSH on the remote dev client (e.g. your
macbook). Here's how to forward both the backend API port (`9090`) and the
frontend live dev server port (`5173`):
```bash
ssh \
-L 9090:localhost:9090 \
-L 5173:localhost:5173 \
user@remote-dev-host
```
The forwarding stops when you close the terminal window, so we suggest doing this
_outside_ the VSCode integrated terminal in case you need to restart VSCode for
an extension update or something similar.
On your remote dev client, you can then open `localhost:9090` and access the UI,
served from the remote dev host, just as if it were running on the client.

View File

@@ -1,4 +1,8 @@
-# Nodes Editor (Experimental Beta)
+# Nodes Editor (Experimental)
+
+🚨
+*The node editor is experimental. We've made it accessible because we use it to develop the application, but we have not addressed the many known rough edges. It's very easy to shoot yourself in the foot, and we cannot offer support for it until it sees full release (ETA v3.1). Everything is subject to change without warning.*
+🚨

 The nodes editor is a blank canvas allowing for the use of individual functions and image transformations to control the image generation workflow. The node processing flow is usually done from left (inputs) to right (outputs), though linearity can become abstracted the more complex the node graph becomes. Nodes inputs and outputs are connected by dragging connectors from node to node.

View File

@@ -11,6 +11,7 @@ from invokeai.app.services.board_images import (
 )
 from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
 from invokeai.app.services.boards import BoardService, BoardServiceDependencies
+from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
 from invokeai.app.services.images import ImageService, ImageServiceDependencies
 from invokeai.app.services.resource_name import SimpleNameService
@@ -20,7 +21,6 @@ from invokeai.version.invokeai_version import __version__
 from ..services.default_graphs import create_system_graphs
 from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
-from ..services.restoration_services import RestorationServices
 from ..services.graph import GraphExecutionState, LibraryGraph
 from ..services.image_file_storage import DiskImageFileStorage
 from ..services.invocation_queue import MemoryInvocationQueue
@@ -57,8 +57,8 @@ class ApiDependencies:
     invoker: Invoker = None

     @staticmethod
-    def initialize(config, event_handler_id: int, logger: Logger = logger):
-        logger.debug(f'InvokeAI version {__version__}')
+    def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
+        logger.debug(f"InvokeAI version {__version__}")
         logger.debug(f"Internet connectivity is {config.internet_available}")

         events = FastAPIEventService(event_handler_id)
@@ -117,7 +117,7 @@ class ApiDependencies:
         )

         services = InvocationServices(
-            model_manager=ModelManagerService(config,logger),
+            model_manager=ModelManagerService(config, logger),
             events=events,
             latents=latents,
             images=images,
@@ -129,7 +129,6 @@
             ),
             graph_execution_manager=graph_execution_manager,
             processor=DefaultInvocationProcessor(),
-            restoration=RestorationServices(config, logger),
             configuration=config,
             logger=logger,
         )
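A rough sketch of the now-typed initialization path (the `get_config()` helper and the way the event-handler id is obtained are assumptions for illustration, not taken from this diff):

```python
from invokeai.app.services.config import InvokeAIAppConfig

# Hypothetical wiring: build the typed app config and hand it to the API layer.
# After this change, RestorationServices is no longer constructed here.
config = InvokeAIAppConfig.get_config()   # assumed helper that parses CLI/env/invokeai.yaml
ApiDependencies.initialize(config=config, event_handler_id=1234)  # event_handler_id is illustrative
```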

View File

@@ -13,8 +13,10 @@ from invokeai.backend import BaseModelType, ModelType
 from invokeai.backend.model_management.models import (
     OPENAPI_MODEL_CONFIGS,
     SchedulerPredictionType,
+    ModelNotFoundException,
 )
 from invokeai.backend.model_management import MergeInterpolationMethod

 from ..dependencies import ApiDependencies

 models_router = APIRouter(prefix="/v1/models", tags=["models"])
@@ -51,8 +53,9 @@ async def list_models(
     "/{base_model}/{model_type}/{model_name}",
     operation_id="update_model",
     responses={200: {"description" : "The model was updated successfully"},
+               400: {"description" : "Bad request"},
                404: {"description" : "The model could not be found"},
-               400: {"description" : "Bad request"}
+               409: {"description" : "There is already a model corresponding to the new name"},
               },
     status_code = 200,
     response_model = UpdateModelResponse,
@@ -63,23 +66,58 @@ async def update_model(
     model_name: str = Path(description="model name"),
     info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
 ) -> UpdateModelResponse:
-    """ Add Model """
+    """ Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. """
+    logger = ApiDependencies.invoker.services.logger
+
     try:
+        previous_info = ApiDependencies.invoker.services.model_manager.list_model(
+            model_name=model_name,
+            base_model=base_model,
+            model_type=model_type,
+        )
+
+        # rename operation requested
+        if info.model_name != model_name or info.base_model != base_model:
+            ApiDependencies.invoker.services.model_manager.rename_model(
+                base_model = base_model,
+                model_type = model_type,
+                model_name = model_name,
+                new_name = info.model_name,
+                new_base = info.base_model,
+            )
+            logger.info(f'Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}')
+            # update information to support an update of attributes
+            model_name = info.model_name
+            base_model = info.base_model
+            new_info = ApiDependencies.invoker.services.model_manager.list_model(
+                model_name=model_name,
+                base_model=base_model,
+                model_type=model_type,
+            )
+            if new_info.get('path') != previous_info.get('path'): # model manager moved model path during rename - don't overwrite it
+                info.path = new_info.get('path')
+
         ApiDependencies.invoker.services.model_manager.update_model(
             model_name=model_name,
             base_model=base_model,
             model_type=model_type,
             model_attributes=info.dict()
         )
         model_raw = ApiDependencies.invoker.services.model_manager.list_model(
             model_name=model_name,
             base_model=base_model,
             model_type=model_type,
         )
         model_response = parse_obj_as(UpdateModelResponse, model_raw)
-    except KeyError as e:
+    except ModelNotFoundException as e:
         raise HTTPException(status_code=404, detail=str(e))
     except ValueError as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=409, detail=str(e))
+    except Exception as e:
+        logger.error(str(e))
         raise HTTPException(status_code=400, detail=str(e))
     return model_response
@@ -126,7 +164,7 @@ async def import_model(
         )
         return parse_obj_as(ImportModelResponse, model_raw)

-    except KeyError as e:
+    except ModelNotFoundException as e:
         logger.error(str(e))
         raise HTTPException(status_code=404, detail=str(e))
     except ValueError as e:
@@ -166,57 +204,13 @@ async def add_model(
             model_type=info.model_type
         )
         return parse_obj_as(ImportModelResponse, model_raw)
-    except KeyError as e:
+    except ModelNotFoundException as e:
         logger.error(str(e))
         raise HTTPException(status_code=404, detail=str(e))
     except ValueError as e:
         logger.error(str(e))
         raise HTTPException(status_code=409, detail=str(e))

-@models_router.post(
-    "/rename/{base_model}/{model_type}/{model_name}",
-    operation_id="rename_model",
-    responses= {
-        201: {"description" : "The model was renamed successfully"},
-        404: {"description" : "The model could not be found"},
-        409: {"description" : "There is already a model corresponding to the new name"},
-    },
-    status_code=201,
-    response_model=ImportModelResponse
-)
-async def rename_model(
-    base_model: BaseModelType = Path(description="Base model"),
-    model_type: ModelType = Path(description="The type of model"),
-    model_name: str = Path(description="current model name"),
-    new_name: Optional[str] = Query(description="new model name", default=None),
-    new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
-) -> ImportModelResponse:
-    """ Rename a model"""
-
-    logger = ApiDependencies.invoker.services.logger
-
-    try:
-        result = ApiDependencies.invoker.services.model_manager.rename_model(
-            base_model = base_model,
-            model_type = model_type,
-            model_name = model_name,
-            new_name = new_name,
-            new_base = new_base,
-        )
-        logger.debug(result)
-        logger.info(f'Successfully renamed {model_name}=>{new_name}')
-        model_raw = ApiDependencies.invoker.services.model_manager.list_model(
-            model_name=new_name or model_name,
-            base_model=new_base or base_model,
-            model_type=model_type
-        )
-        return parse_obj_as(ImportModelResponse, model_raw)
-    except KeyError as e:
-        logger.error(str(e))
-        raise HTTPException(status_code=404, detail=str(e))
-    except ValueError as e:
-        logger.error(str(e))
-        raise HTTPException(status_code=409, detail=str(e))

 @models_router.delete(
     "/{base_model}/{model_type}/{model_name}",
@@ -243,9 +237,9 @@ async def delete_model(
         )
         logger.info(f"Deleted model: {model_name}")
         return Response(status_code=204)
-    except KeyError:
-        logger.error(f"Model not found: {model_name}")
-        raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
+    except ModelNotFoundException as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=404, detail=str(e))

 @models_router.put(
     "/convert/{base_model}/{model_type}/{model_name}",
@@ -278,8 +272,8 @@ async def convert_model(
             base_model = base_model,
             model_type = model_type)
         response = parse_obj_as(ConvertModelResponse, model_raw)
-    except KeyError:
-        raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
+    except ModelNotFoundException as e:
+        raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
     except ValueError as e:
         raise HTTPException(status_code=400, detail=str(e))
     return response
@@ -369,8 +363,55 @@ async def merge_models(
             model_type = ModelType.Main,
         )
         response = parse_obj_as(ConvertModelResponse, model_raw)
-    except KeyError:
+    except ModelNotFoundException:
         raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
     except ValueError as e:
         raise HTTPException(status_code=400, detail=str(e))
     return response
# The rename operation is now supported by update_model and no longer needs to be
# a standalone route.
# @models_router.post(
# "/rename/{base_model}/{model_type}/{model_name}",
# operation_id="rename_model",
# responses= {
# 201: {"description" : "The model was renamed successfully"},
# 404: {"description" : "The model could not be found"},
# 409: {"description" : "There is already a model corresponding to the new name"},
# },
# status_code=201,
# response_model=ImportModelResponse
# )
# async def rename_model(
# base_model: BaseModelType = Path(description="Base model"),
# model_type: ModelType = Path(description="The type of model"),
# model_name: str = Path(description="current model name"),
# new_name: Optional[str] = Query(description="new model name", default=None),
# new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
# ) -> ImportModelResponse:
# """ Rename a model"""
# logger = ApiDependencies.invoker.services.logger
# try:
# result = ApiDependencies.invoker.services.model_manager.rename_model(
# base_model = base_model,
# model_type = model_type,
# model_name = model_name,
# new_name = new_name,
# new_base = new_base,
# )
# logger.debug(result)
# logger.info(f'Successfully renamed {model_name}=>{new_name}')
# model_raw = ApiDependencies.invoker.services.model_manager.list_model(
# model_name=new_name or model_name,
# base_model=new_base or base_model,
# model_type=model_type
# )
# return parse_obj_as(ImportModelResponse, model_raw)
# except ModelNotFoundException as e:
# logger.error(str(e))
# raise HTTPException(status_code=404, detail=str(e))
# except ValueError as e:
# logger.error(str(e))
# raise HTTPException(status_code=409, detail=str(e))
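As a rough illustration of the new flow (not part of the commit): renaming now happens through `update_model` by submitting a config whose `model_name` or `base_model` differs from the path parameters. A hedged sketch with `requests`; the URL prefix, HTTP verb, and body fields are assumptions, and the model names are made up:

```python
import requests

# Hypothetical: rename the sd-1 main model "old-name" to "new-name".
url = "http://localhost:9090/api/v1/models/sd-1/main/old-name"   # assumed mount point
body = {
    "model_name": "new-name",        # differs from the path -> triggers the rename branch
    "base_model": "sd-1",
    "model_type": "main",
    "model_format": "diffusers",     # illustrative; real fields follow OPENAPI_MODEL_CONFIGS
    "path": "",                      # the server preserves the moved path if the rename relocated the model
    "description": "renamed via update_model",
}
resp = requests.patch(url, json=body)  # assumed verb for the update route
print(resp.status_code)  # 200 on success; 404 unknown model; 409 name conflict; 400 otherwise
```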

View File

@@ -39,6 +39,7 @@ from .invocations.baseinvocation import BaseInvocation
 import torch
+import invokeai.backend.util.hotfixes

 if torch.backends.mps.is_available():
     import invokeai.backend.util.mps_fixes

View File

@@ -54,10 +54,10 @@ from .services.invocation_services import InvocationServices
 from .services.invoker import Invoker
 from .services.model_manager_service import ModelManagerService
 from .services.processor import DefaultInvocationProcessor
-from .services.restoration_services import RestorationServices
 from .services.sqlite import SqliteItemStorage

 import torch
+import invokeai.backend.util.hotfixes

 if torch.backends.mps.is_available():
     import invokeai.backend.util.mps_fixes
@@ -295,7 +295,6 @@ def invoke_cli():
         ),
         graph_execution_manager=graph_execution_manager,
         processor=DefaultInvocationProcessor(),
-        restoration=RestorationServices(config,logger=logger),
         logger=logger,
         configuration=config,
     )

View File

@@ -86,10 +86,10 @@ class CompelInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> CompelOutput:
         tokenizer_info = context.services.model_manager.get_model(
-            **self.clip.tokenizer.dict(),
+            **self.clip.tokenizer.dict(), context=context,
         )
         text_encoder_info = context.services.model_manager.get_model(
-            **self.clip.text_encoder.dict(),
+            **self.clip.text_encoder.dict(), context=context,
         )

         def _lora_loader():
@@ -111,6 +111,7 @@ class CompelInvocation(BaseInvocation):
                         model_name=name,
                         base_model=self.clip.text_encoder.base_model,
                         model_type=ModelType.TextualInversion,
+                        context=context,
                     ).context.model
                 )
             except ModelNotFoundException:

View File

@@ -157,13 +157,13 @@ class InpaintInvocation(BaseInvocation):
         def _lora_loader():
             for lora in self.unet.loras:
                 lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"}))
+                    **lora.dict(exclude={"weight"}), context=context,)
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
             return

-        unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
-        vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
+        unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context,)
+        vae_info = context.services.model_manager.get_model(**self.vae.vae.dict(), context=context,)

         with vae_info as vae,\
             ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\

View File

@@ -83,7 +83,7 @@ def get_scheduler(
         scheduler_name, SCHEDULER_MAP['ddim']
     )
     orig_scheduler_info = context.services.model_manager.get_model(
-        **scheduler_info.dict()
+        **scheduler_info.dict(), context=context,
     )
     with orig_scheduler_info as orig_scheduler:
         scheduler_config = orig_scheduler.config
@@ -269,6 +269,7 @@ class TextToLatentsInvocation(BaseInvocation):
                     model_name=control_info.control_model.model_name,
                     model_type=ModelType.ControlNet,
                     base_model=control_info.control_model.base_model,
+                    context=context,
                 )
             )
@@ -320,14 +321,14 @@ class TextToLatentsInvocation(BaseInvocation):
         def _lora_loader():
             for lora in self.unet.loras:
                 lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"})
+                    **lora.dict(exclude={"weight"}), context=context,
                 )
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
             return

         unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.dict()
+            **self.unet.unet.dict(), context=context,
         )

         with ExitStack() as exit_stack,\
             ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
@@ -410,14 +411,14 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         def _lora_loader():
             for lora in self.unet.loras:
                 lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"})
+                    **lora.dict(exclude={"weight"}), context=context,
                 )
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
             return

         unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.dict()
+            **self.unet.unet.dict(), context=context,
         )

         with ExitStack() as exit_stack,\
             ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
@@ -498,7 +499,7 @@ class LatentsToImageInvocation(BaseInvocation):
         latents = context.services.latents.get(self.latents.latents_name)

         vae_info = context.services.model_manager.get_model(
-            **self.vae.vae.dict(),
+            **self.vae.vae.dict(), context=context,
         )

         with vae_info as vae:
@@ -670,7 +671,7 @@ class ImageToLatentsInvocation(BaseInvocation):
         #vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
         vae_info = context.services.model_manager.get_model(
-            **self.vae.vae.dict(),
+            **self.vae.vae.dict(), context=context,
         )

         image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))

View File

@@ -1,6 +1,8 @@
-from typing import Literal
+from os.path import exists
+from typing import Literal, Optional

-from pydantic.fields import Field
+import numpy as np
+from pydantic import Field, validator

 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
 from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
@@ -55,3 +57,41 @@ class DynamicPromptInvocation(BaseInvocation):
         prompts = generator.generate(self.prompt, num_images=self.max_prompts)
         return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
class PromptsFromFileInvocation(BaseInvocation):
    '''Loads prompts from a text file'''
    # fmt: off
    type: Literal['prompt_from_file'] = 'prompt_from_file'

    # Inputs
    file_path: str = Field(description="Path to prompt text file")
    pre_prompt: Optional[str] = Field(description="String to prepend to each prompt")
    post_prompt: Optional[str] = Field(description="String to append to each prompt")
    start_line: int = Field(default=1, ge=1, description="Line in the file to start from")
    max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)")
    #fmt: on

    @validator("file_path")
    def file_path_exists(cls, v):
        if not exists(v):
            raise ValueError(FileNotFoundError)
        return v

    def promptsFromFile(self, file_path: str, pre_prompt: str, post_prompt: str, start_line: int, max_prompts: int):
        prompts = []
        start_line -= 1
        end_line = start_line + max_prompts
        if max_prompts <= 0:
            end_line = np.iinfo(np.int32).max
        with open(file_path) as f:
            for i, line in enumerate(f):
                if i >= start_line and i < end_line:
                    prompts.append((pre_prompt or '') + line.strip() + (post_prompt or ''))
                if i >= end_line:
                    break
        return prompts

    def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
        prompts = self.promptsFromFile(self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts)
        return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
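As a usage illustration (not part of the commit): the windowing is zero-based after the `start_line -= 1` adjustment, so `start_line=3, max_prompts=2` keeps file lines 3 and 4. A hedged sketch with made-up id, file name, and prompts:

```python
# Hypothetical usage of the new node; the id and file contents are invented.
node = PromptsFromFileInvocation(
    id="prompts_from_file_1",
    file_path="prompts.txt",        # made-up path
    pre_prompt="photo of ",
    post_prompt=", 4k, detailed",
    start_line=3,
    max_prompts=2,
)
# node.promptsFromFile("prompts.txt", "photo of ", ", 4k, detailed", 3, 2) would return
# something like ["photo of a red fox, 4k, detailed", "photo of a mountain lake, 4k, detailed"]
```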

View File

@@ -1,55 +0,0 @@
from typing import Literal, Optional
from pydantic import Field
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
class RestoreFaceInvocation(BaseInvocation):
"""Restores faces in an image."""
# fmt: off
type: Literal["restore_face"] = "restore_face"
# Inputs
image: Optional[ImageField] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
# fmt: on
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["restoration", "image"],
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=None,
strength=self.strength, # GFPGAN strength
save_original=False,
image_callback=None,
)
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_dto = context.services.images.create(
image=results[0][0],
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@@ -1,48 +1,112 @@
-# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
+# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
+from pathlib import Path, PosixPath

-from typing import Literal, Optional
+from typing import Literal, Union, cast

+import cv2 as cv
+import numpy as np
+from basicsr.archs.rrdbnet_arch import RRDBNet
+from PIL import Image
 from pydantic import Field
+from realesrgan import RealESRGANer

 from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
-from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
+from .baseinvocation import BaseInvocation, InvocationContext
 from .image import ImageOutput

+# TODO: Populate this from disk?
+# TODO: Use model manager to load?
+REALESRGAN_MODELS = Literal[
+    "RealESRGAN_x4plus.pth",
+    "RealESRGAN_x4plus_anime_6B.pth",
+    "ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
+]

-class UpscaleInvocation(BaseInvocation):
-    """Upscales an image."""
-    # fmt: off
-    type: Literal["upscale"] = "upscale"
-
-    # Inputs
-    image: Optional[ImageField] = Field(description="The input image", default=None)
-    strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
-    level: Literal[2, 4] = Field(default=2, description="The upscale level")
-    # fmt: on
-
-    # Schema customisation
-    class Config(InvocationConfig):
-        schema_extra = {
-            "ui": {
-                "tags": ["upscaling", "image"],
-            },
-        }
+class RealESRGANInvocation(BaseInvocation):
+    """Upscales an image using RealESRGAN."""
+
+    type: Literal["realesrgan"] = "realesrgan"
+    image: Union[ImageField, None] = Field(default=None, description="The input image")
+    model_name: REALESRGAN_MODELS = Field(
+        default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use"
+    )

     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.services.images.get_pil_image(self.image.image_name)
-        results = context.services.restoration.upscale_and_reconstruct(
-            image_list=[[image, 0]],
-            upscale=(self.level, self.strength),
-            strength=0.0,  # GFPGAN strength
-            save_original=False,
-            image_callback=None,
-        )
+        models_path = context.services.configuration.models_path
+
+        rrdbnet_model = None
+        netscale = None
+        esrgan_model_path = None
+
+        if self.model_name in [
+            "RealESRGAN_x4plus.pth",
+            "ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
+        ]:
+            # x4 RRDBNet model
+            rrdbnet_model = RRDBNet(
+                num_in_ch=3,
+                num_out_ch=3,
+                num_feat=64,
+                num_block=23,
+                num_grow_ch=32,
+                scale=4,
+            )
+            netscale = 4
+        elif self.model_name in ["RealESRGAN_x4plus_anime_6B.pth"]:
+            # x4 RRDBNet model, 6 blocks
+            rrdbnet_model = RRDBNet(
+                num_in_ch=3,
+                num_out_ch=3,
+                num_feat=64,
+                num_block=6,  # 6 blocks
+                num_grow_ch=32,
+                scale=4,
+            )
+            netscale = 4
+        # TODO: add x2 models handling?
+        # elif self.model_name in ["RealESRGAN_x2plus"]:
+        #     # x2 RRDBNet model
+        #     model = RRDBNet(
+        #         num_in_ch=3,
+        #         num_out_ch=3,
+        #         num_feat=64,
+        #         num_block=23,
+        #         num_grow_ch=32,
+        #         scale=2,
+        #     )
+        #     model_path = Path()
+        #     netscale = 2
+        else:
+            msg = f"Invalid RealESRGAN model: {self.model_name}"
+            context.services.logger.error(msg)
+            raise ValueError(msg)
+
+        esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
+
+        upsampler = RealESRGANer(
+            scale=netscale,
+            model_path=str(models_path / esrgan_model_path),
+            model=rrdbnet_model,
+            half=False,
+        )
+
+        # prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
+        cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)

-        # Results are image and seed, unwrap for now
-        # TODO: can this return multiple results?
+        # We can pass an `outscale` value here, but it just resizes the image by that factor after
+        # upscaling, so it's kinda pointless for our purposes. If you want something other than 4x
+        # upscaling, you'll need to add a resize node after this one.
+        upscaled_image, img_mode = upsampler.enhance(cv_image)
+
+        # back to PIL
+        pil_image = Image.fromarray(
+            cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)
+        ).convert("RGBA")
+
         image_dto = context.services.images.create(
-            image=results[0][0],
+            image=pil_image,
             image_origin=ResourceOrigin.INTERNAL,
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
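A minimal standalone sketch of the PIL/OpenCV color-order round trip the comments above describe (assumes `opencv-python`, `numpy`, and `Pillow` are installed; unrelated to the InvokeAI services):

```python
import cv2 as cv
import numpy as np
from PIL import Image

pil_in = Image.new("RGB", (64, 64), (255, 0, 0))        # pure red, RGB order
bgr = cv.cvtColor(np.array(pil_in), cv.COLOR_RGB2BGR)   # OpenCV works in BGR
# ... any cv2-based processing (e.g. a Real-ESRGAN upscaler) happens here ...
pil_out = Image.fromarray(cv.cvtColor(bgr, cv.COLOR_BGR2RGB)).convert("RGBA")
assert pil_out.getpixel((0, 0))[:3] == (255, 0, 0)      # still red after the round trip
```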

View File

@@ -271,13 +271,13 @@ class InvokeAISettings(BaseSettings):
     @classmethod
     def _excluded(self)->List[str]:
-        # combination of deprecated parameters and internal ones that shouldn't be exposed
+        # internal fields that shouldn't be exposed as command line options
         return ['type','initconf']

     @classmethod
     def _excluded_from_yaml(self)->List[str]:
-        # combination of deprecated parameters and internal ones that shouldn't be exposed
-        return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model']
+        # combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
+        return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model', 'restore']

     class Config:
         env_file_encoding = 'utf-8'
@@ -366,7 +366,7 @@ setting environment variables INVOKEAI_<setting>.
     log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
     nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
     patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
-    restore : bool = Field(default=True, description="Enable/disable face restoration code", category='Features')
+    restore : bool = Field(default=True, description="Enable/disable face restoration code (DEPRECATED)", category='DEPRECATED')
     always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
     free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
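For background, a simplified sketch of the exclusion pattern (hypothetical `SketchSettings` class and `to_yaml_dict` helper, shown only to illustrate how a deprecated field can stay on the model while being dropped from the written YAML; it is not the actual InvokeAI implementation):

```python
from pydantic import BaseSettings, Field  # pydantic v1 style, as used here

class SketchSettings(BaseSettings):
    patchmatch: bool = Field(default=True, description="Enable/disable patchmatch inpaint code")
    restore: bool = Field(default=True, description="DEPRECATED face restoration toggle")

    @classmethod
    def _excluded_from_yaml(cls) -> list:
        # deprecated or internal fields that should not appear in invokeai.yaml
        return ["restore"]

    def to_yaml_dict(self) -> dict:
        return {k: v for k, v in self.dict().items() if k not in self._excluded_from_yaml()}

print(SketchSettings().to_yaml_dict())   # {'patchmatch': True}
```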

View File

@@ -105,8 +105,6 @@ class EventServiceBase:
     def emit_model_load_started (
         self,
         graph_execution_state_id: str,
-        node: dict,
-        source_node_id: str,
         model_name: str,
         base_model: BaseModelType,
         model_type: ModelType,
@@ -117,8 +115,6 @@ class EventServiceBase:
             event_name="model_load_started",
             payload=dict(
                 graph_execution_state_id=graph_execution_state_id,
-                node=node,
-                source_node_id=source_node_id,
                 model_name=model_name,
                 base_model=base_model,
                 model_type=model_type,
@@ -129,8 +125,6 @@ class EventServiceBase:
     def emit_model_load_completed(
         self,
         graph_execution_state_id: str,
-        node: dict,
-        source_node_id: str,
         model_name: str,
         base_model: BaseModelType,
         model_type: ModelType,
@@ -142,12 +136,12 @@ class EventServiceBase:
             event_name="model_load_completed",
             payload=dict(
                 graph_execution_state_id=graph_execution_state_id,
-                node=node,
-                source_node_id=source_node_id,
                 model_name=model_name,
                 base_model=base_model,
                 model_type=model_type,
                 submodel=submodel,
-                model_info=model_info,
+                hash=model_info.hash,
+                location=model_info.location,
+                precision=str(model_info.precision),
             ),
         )
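For illustration, the slimmed-down `model_load_completed` payload now carries scalar model-info fields instead of the whole node; a hypothetical example (all values made up):

```python
payload = dict(
    graph_execution_state_id="3f1c...",                 # made-up id
    model_name="stable-diffusion-xl-base-1.0",          # made-up model
    base_model="sdxl",
    model_type="main",
    submodel="unet",
    hash="abc123def456",                                # model_info.hash
    location="/invokeai/models/sdxl/main/sdxl-base",    # model_info.location
    precision="torch.float16",                          # str(model_info.precision)
)
```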

View File

@@ -10,10 +10,9 @@ if TYPE_CHECKING:
     from invokeai.app.services.model_manager_service import ModelManagerServiceBase
     from invokeai.app.services.events import EventServiceBase
     from invokeai.app.services.latent_storage import LatentsStorageBase
-    from invokeai.app.services.restoration_services import RestorationServices
    from invokeai.app.services.invocation_queue import InvocationQueueABC
     from invokeai.app.services.item_storage import ItemStorageABC
-    from invokeai.app.services.config import InvokeAISettings
+    from invokeai.app.services.config import InvokeAIAppConfig
     from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
     from invokeai.app.services.invoker import InvocationProcessorABC
@@ -24,7 +23,7 @@ class InvocationServices:
     # TODO: Just forward-declared everything due to circular dependencies. Fix structure.
     board_images: "BoardImagesServiceABC"
     boards: "BoardServiceABC"
-    configuration: "InvokeAISettings"
+    configuration: "InvokeAIAppConfig"
     events: "EventServiceBase"
     graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
     graph_library: "ItemStorageABC"["LibraryGraph"]
@@ -34,13 +33,12 @@ class InvocationServices:
     model_manager: "ModelManagerServiceBase"
     processor: "InvocationProcessorABC"
     queue: "InvocationQueueABC"
-    restoration: "RestorationServices"

     def __init__(
         self,
         board_images: "BoardImagesServiceABC",
         boards: "BoardServiceABC",
-        configuration: "InvokeAISettings",
+        configuration: "InvokeAIAppConfig",
         events: "EventServiceBase",
         graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
         graph_library: "ItemStorageABC"["LibraryGraph"],
@@ -50,7 +48,6 @@ class InvocationServices:
         model_manager: "ModelManagerServiceBase",
         processor: "InvocationProcessorABC",
         queue: "InvocationQueueABC",
-        restoration: "RestorationServices",
     ):
         self.board_images = board_images
         self.boards = boards
@@ -65,4 +62,3 @@ class InvocationServices:
         self.model_manager = model_manager
         self.processor = processor
         self.queue = queue
-        self.restoration = restoration

View File

@@ -18,6 +18,7 @@ from invokeai.backend.model_management import (
     SchedulerPredictionType,
     ModelMerger,
     MergeInterpolationMethod,
+    ModelNotFoundException,
 )
 from invokeai.backend.model_management.model_search import FindModels
@@ -145,7 +146,7 @@ class ModelManagerServiceBase(ABC):
     ) -> AddModelResult:
         """
         Update the named model with a dictionary of attributes. Will fail with a
-        KeyErrorException if the name does not already exist.
+        ModelNotFoundException if the name does not already exist.

         On a successful update, the config will be changed in memory. Will fail
         with an assertion error if provided attributes are incorrect or
@@ -338,7 +339,6 @@ class ModelManagerService(ModelManagerServiceBase):
         base_model: BaseModelType,
         model_type: ModelType,
         submodel: Optional[SubModelType] = None,
-        node: Optional[BaseInvocation] = None,
         context: Optional[InvocationContext] = None,
     ) -> ModelInfo:
         """
@@ -346,11 +346,9 @@ class ModelManagerService(ModelManagerServiceBase):
         part (such as the vae) of a diffusers mode.
         """
-        # if we are called from within a node, then we get to emit
-        # load start and complete events
-        if node and context:
+        # we can emit model loading events if we are executing with access to the invocation context
+        if context:
             self._emit_load_event(
-                node=node,
                 context=context,
                 model_name=model_name,
                 base_model=base_model,
@@ -365,9 +363,8 @@ class ModelManagerService(ModelManagerServiceBase):
             submodel,
         )

-        if node and context:
+        if context:
             self._emit_load_event(
-                node=node,
                 context=context,
                 model_name=model_name,
                 base_model=base_model,
@@ -451,14 +448,14 @@ class ModelManagerService(ModelManagerServiceBase):
     ) -> AddModelResult:
         """
         Update the named model with a dictionary of attributes. Will fail with a
-        KeyError exception if the name does not already exist.
+        ModelNotFoundException exception if the name does not already exist.

         On a successful update, the config will be changed in memory. Will fail
         with an assertion error if provided attributes are incorrect or
         the model name is missing. Call commit() to write changes to disk.
         """
         self.logger.debug(f'update model {model_name}')
         if not self.model_exists(model_name, base_model, model_type):
-            raise KeyError(f"Unknown model {model_name}")
+            raise ModelNotFoundException(f"Unknown model {model_name}")
         return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)

     def del_model(
@@ -509,23 +506,19 @@ class ModelManagerService(ModelManagerServiceBase):
     def _emit_load_event(
         self,
-        node,
         context,
         model_name: str,
         base_model: BaseModelType,
         model_type: ModelType,
-        submodel: SubModelType,
+        submodel: Optional[SubModelType] = None,
         model_info: Optional[ModelInfo] = None,
     ):
         if context.services.queue.is_canceled(context.graph_execution_state_id):
             raise CanceledException()
-        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
-        source_node_id = graph_execution_state.prepared_source_mapping[node.id]
         if model_info:
             context.services.events.emit_model_load_completed(
                 graph_execution_state_id=context.graph_execution_state_id,
-                node=node.dict(),
-                source_node_id=source_node_id,
                 model_name=model_name,
                 base_model=base_model,
                 model_type=model_type,
@@ -535,8 +528,6 @@ class ModelManagerService(ModelManagerServiceBase):
         else:
             context.services.events.emit_model_load_started(
                 graph_execution_state_id=context.graph_execution_state_id,
-                node=node.dict(),
-                source_node_id=source_node_id,
                 model_name=model_name,
                 base_model=base_model,
                 model_type=model_type,

View File

@@ -1,113 +0,0 @@
import sys
import traceback
import torch
from typing import types
from ...backend.restoration import Restoration
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
# This should be a real base class for postprocessing functions,
# but right now we just instantiate the existing gfpgan, esrgan
# and codeformer functions.
class RestorationServices:
'''Face restoration and upscaling'''
def __init__(self,args,logger:types.ModuleType):
try:
gfpgan, codeformer, esrgan = None, None, None
if args.restore or args.esrgan:
restoration = Restoration()
# TODO: redo for new model structure
if False and args.restore:
gfpgan, codeformer = restoration.load_face_restore_models(
args.gfpgan_model_path
)
else:
logger.info("Face restoration disabled")
if False and args.esrgan:
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
else:
logger.info("Upscaling disabled")
else:
logger.info("Face restoration and upscaling disabled")
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
self.device = torch.device(choose_torch_device())
self.gfpgan = gfpgan
self.codeformer = codeformer
self.esrgan = esrgan
self.logger = logger
self.logger.info('Face restoration initialized')
# note that this one method does gfpgan and codepath reconstruction, as well as
# esrgan upscaling
# TO DO: refactor into separate methods
def upscale_and_reconstruct(
self,
image_list,
facetool="gfpgan",
upscale=None,
upscale_denoise_str=0.75,
strength=0.0,
codeformer_fidelity=0.75,
save_original=False,
image_callback=None,
prefix=None,
):
results = []
for r in image_list:
image, seed = r
try:
if strength > 0:
if self.gfpgan is not None or self.codeformer is not None:
if facetool == "gfpgan":
if self.gfpgan is None:
self.logger.info(
"GFPGAN not found. Face restoration is disabled."
)
else:
image = self.gfpgan.process(image, strength, seed)
if facetool == "codeformer":
if self.codeformer is None:
self.logger.info(
"CodeFormer not found. Face restoration is disabled."
)
else:
cf_device = (
CPU_DEVICE if self.device == MPS_DEVICE else self.device
)
image = self.codeformer.process(
image=image,
strength=strength,
device=cf_device,
seed=seed,
fidelity=codeformer_fidelity,
)
else:
self.logger.info("Face Restoration is disabled.")
if upscale is not None:
if self.esrgan is not None:
if len(upscale) < 2:
upscale.append(0.75)
image = self.esrgan.process(
image,
upscale[1],
seed,
int(upscale[0]),
denoise_str=upscale_denoise_str,
)
else:
self.logger.info("ESRGAN is disabled. Image not upscaled.")
except Exception as e:
self.logger.info(
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
)
if image_callback is not None:
image_callback(image, seed, upscaled=True, use_prefix=prefix)
else:
r[0] = image
results.append([image, seed])
return results

View File

@ -30,8 +30,6 @@ from huggingface_hub import login as hf_hub_login
from omegaconf import OmegaConf from omegaconf import OmegaConf
from tqdm import tqdm from tqdm import tqdm
from transformers import ( from transformers import (
AutoProcessor,
CLIPSegForImageSegmentation,
CLIPTextModel, CLIPTextModel,
CLIPTokenizer, CLIPTokenizer,
AutoFeatureExtractor, AutoFeatureExtractor,
@ -45,7 +43,6 @@ from invokeai.app.services.config import (
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
from invokeai.frontend.install.widgets import ( from invokeai.frontend.install.widgets import (
SingleSelectColumns,
CenteredButtonPress, CenteredButtonPress,
IntTitleSlider, IntTitleSlider,
set_min_terminal_size, set_min_terminal_size,
@@ -226,64 +223,30 @@ def download_conversion_models():
 # ---------------------------------------------
 def download_realesrgan():
-    logger.info("Installing models from RealESRGAN...")
-    model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
-    wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth"
-    model_dest = config.root_path / "models/core/upscaling/realesrgan/realesr-general-x4v3.pth"
-    wdn_model_dest = config.root_path / "models/core/upscaling/realesrgan/realesr-general-wdn-x4v3.pth"
-    download_with_progress_bar(model_url, str(model_dest), "RealESRGAN")
-    download_with_progress_bar(wdn_model_url, str(wdn_model_dest), "RealESRGANwdn")
-def download_gfpgan():
-    logger.info("Installing GFPGAN models...")
-    for model in (
-        [
-            "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
-            "./models/core/face_restoration/gfpgan/GFPGANv1.4.pth",
-        ],
-        [
-            "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth",
-            "./models/core/face_restoration/gfpgan/weights/detection_Resnet50_Final.pth",
-        ],
-        [
-            "https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth",
-            "./models/core/face_restoration/gfpgan/weights/parsing_parsenet.pth",
-        ],
-    ):
-        model_url, model_dest = model[0], config.root_path / model[1]
-        download_with_progress_bar(model_url, str(model_dest), "GFPGAN weights")
+    logger.info("Installing RealESRGAN models...")
+    URLs = [
+        dict(
+            url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
+            dest = "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
+            description = "RealESRGAN_x4plus.pth",
+        ),
+        dict(
+            url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
+            dest = "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
+            description = "RealESRGAN_x4plus_anime_6B.pth",
+        ),
+        dict(
+            url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
+            dest = "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
+            description = "ESRGAN_SRx4_DF2KOST_official.pth",
+        ),
+    ]
+    for model in URLs:
+        download_with_progress_bar(model['url'], config.models_path / model['dest'], model['description'])
 # ---------------------------------------------
-def download_codeformer():
-    logger.info("Installing CodeFormer model file...")
-    model_url = (
-        "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
-    )
-    model_dest = config.root_path / "models/core/face_restoration/codeformer/codeformer.pth"
-    download_with_progress_bar(model_url, str(model_dest), "CodeFormer")
-# ---------------------------------------------
-def download_clipseg():
-    logger.info("Installing clipseg model for text-based masking...")
-    CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
-    try:
-        hf_download_from_pretrained(AutoProcessor, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
-        hf_download_from_pretrained(CLIPSegForImageSegmentation, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
-    except Exception:
-        logger.info("Error installing clipseg model:")
-        logger.info(traceback.format_exc())
 def download_support_models():
     download_realesrgan()
-    download_gfpgan()
-    download_codeformer()
-    download_clipseg()
     download_conversion_models()
 # -------------------------------------
@@ -858,9 +821,9 @@ def main():
             download_support_models()
             if opt.skip_sd_weights:
-                logger.info("\n** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **")
+                logger.warning("SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST")
             elif models_to_download:
-                logger.info("\n** DOWNLOADING DIFFUSION WEIGHTS **")
+                logger.info("DOWNLOADING DIFFUSION WEIGHTS")
                 process_and_execute(opt, models_to_download)
         postscript(errors=errors)

View File

@@ -117,6 +117,7 @@ class ModelInstall(object):
         # supplement with entries in models.yaml
         installed_models = self.mgr.list_models()
         for md in installed_models:
             base = md['base_model']
             model_type = md['model_type']
@@ -134,6 +135,12 @@ class ModelInstall(object):
         )
         return {x : model_dict[x] for x in sorted(model_dict.keys(),key=lambda y: model_dict[y].name.lower())}
+    def list_models(self, model_type):
+        installed = self.mgr.list_models(model_type=model_type)
+        print(f'Installed models of type `{model_type}`:')
+        for i in installed:
+            print(f"{i['model_name']}\t{i['base_model']}\t{i['path']}")
     def starter_models(self)->Set[str]:
         models = set()
         for key, value in self.datasets.items():

View File

@@ -3,6 +3,6 @@ Initialization file for invokeai.backend.model_management
 """
 from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
 from .model_cache import ModelCache
-from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
+from .models import BaseModelType, ModelType, SubModelType, ModelVariantType, ModelNotFoundException
 from .model_merge import ModelMerger, MergeInterpolationMethod

View File

@@ -552,7 +552,7 @@ class ModelManager(object):
         model_config = self.models.get(model_key)
         if not model_config:
             self.logger.error(f'Unknown model {model_name}')
-            raise KeyError(f'Unknown model {model_name}')
+            raise ModelNotFoundException(f'Unknown model {model_name}')
         cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
         if base_model is not None and cur_base_model != base_model:
@@ -568,6 +568,9 @@ class ModelManager(object):
                 model_type=cur_model_type,
             )
+            # expose paths as absolute to help web UI
+            if path := model_dict.get('path'):
+                model_dict['path'] = str(self.app_config.root_path / path)
             models.append(model_dict)
         return models
@@ -596,7 +599,7 @@ class ModelManager(object):
         model_cfg = self.models.pop(model_key, None)
         if model_cfg is None:
-            raise KeyError(f"Unknown model {model_key}")
+            raise ModelNotFoundException(f"Unknown model {model_key}")
         # note: it not garantie to release memory(model can has other references)
         cache_ids = self.cache_keys.pop(model_key, [])
@@ -635,6 +638,10 @@ class ModelManager(object):
         The returned dict has the same format as the dict returned by
         model_info().
         """
+        # relativize paths as they go in - this makes it easier to move the root directory around
+        if path := model_attributes.get('path'):
+            if Path(path).is_relative_to(self.app_config.root_path):
+                model_attributes['path'] = str(Path(path).relative_to(self.app_config.root_path))
         model_class = MODEL_CLASSES[base_model][model_type]
         model_config = model_class.create_config(**model_attributes)
@@ -689,7 +696,7 @@ class ModelManager(object):
         model_key = self.create_key(model_name, base_model, model_type)
         model_cfg = self.models.get(model_key, None)
         if not model_cfg:
-            raise KeyError(f"Unknown model: {model_key}")
+            raise ModelNotFoundException(f"Unknown model: {model_key}")
         old_path = self.app_config.root_path / model_cfg.path
         new_name = new_name or model_name
@@ -700,7 +707,7 @@ class ModelManager(object):
         # if this is a model file/directory that we manage ourselves, we need to move it
         if old_path.is_relative_to(self.app_config.models_path):
-            new_path = self.app_config.root_path / 'models' / new_base.value / model_type.value / new_name
+            new_path = self.app_config.root_path / 'models' / BaseModelType(new_base).value / ModelType(model_type).value / new_name
             move(old_path, new_path)
             model_cfg.path = str(new_path.relative_to(self.app_config.root_path))
@@ -908,7 +915,6 @@ class ModelManager(object):
         from invokeai.backend.install.model_install_backend import ModelInstall
         from invokeai.frontend.install.model_install import ask_user_for_prediction_type
         class ScanAndImport(ModelSearch):
             def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall):
                 super().__init__(directories, logger)
@@ -965,7 +971,7 @@ class ModelManager(object):
         that model.
         May return the following exceptions:
-        - KeyError - one or more of the items to import is not a valid path, repo_id or URL
+        - ModelNotFoundException - one or more of the items to import is not a valid path, repo_id or URL
         - ValueError - a corresponding model already exists
         '''
         # avoid circular import here
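
The two path-handling additions in the hunks above are inverses of each other: model paths are stored relative to the InvokeAI root when a model is registered, and expanded back to absolute paths when models are listed for the web UI. A minimal standalone sketch of that round trip with `pathlib`, under an assumed root directory and model path (both placeholders, not the repository's actual layout):

```python
from pathlib import Path

root_path = Path("/opt/invokeai")                           # placeholder install root
incoming = Path("/opt/invokeai/models/sd-1/main/my-model")  # path as supplied by the caller

# On the way in: store relative to the root so the whole tree can be relocated.
if incoming.is_relative_to(root_path):              # Python 3.9+
    stored = str(incoming.relative_to(root_path))
else:
    stored = str(incoming)                          # outside the root: keep as-is

# On the way out: expand to absolute so web clients need no knowledge of the root.
exposed = str(root_path / stored)

print(stored)   # models/sd-1/main/my-model
print(exposed)  # /opt/invokeai/models/sd-1/main/my-model
```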

View File

@@ -68,7 +68,7 @@ class TextualInversionModel(ModelBase):
             return None # diffusers-ti
         if os.path.isfile(path):
-            if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
+            if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]]):
                 return None
         raise InvalidModelException(f"Not a valid model: {path}")

View File

@@ -16,6 +16,7 @@ from .base import (
     calc_model_size_by_data,
     classproperty,
     InvalidModelException,
+    ModelNotFoundException,
 )
 from invokeai.app.services.config import InvokeAIAppConfig
 from diffusers.utils import is_safetensors_available

View File

@ -1,4 +0,0 @@
"""
Initialization file for the invokeai.backend.restoration package
"""
from .base import Restoration

View File

@ -1,45 +0,0 @@
import invokeai.backend.util.logging as logger
class Restoration:
def __init__(self) -> None:
pass
def load_face_restore_models(
self, gfpgan_model_path="./models/core/face_restoration/gfpgan/GFPGANv1.4.pth"
):
# Load GFPGAN
gfpgan = self.load_gfpgan(gfpgan_model_path)
if gfpgan.gfpgan_model_exists:
logger.info("GFPGAN Initialized")
else:
logger.info("GFPGAN Disabled")
gfpgan = None
# Load CodeFormer
codeformer = self.load_codeformer()
if codeformer.codeformer_model_exists:
logger.info("CodeFormer Initialized")
else:
logger.info("CodeFormer Disabled")
codeformer = None
return gfpgan, codeformer
# Face Restore Models
def load_gfpgan(self, gfpgan_model_path):
from .gfpgan import GFPGAN
return GFPGAN(gfpgan_model_path)
def load_codeformer(self):
from .codeformer import CodeFormerRestoration
return CodeFormerRestoration()
# Upscale Models
def load_esrgan(self, esrgan_bg_tile=400):
from .realesrgan import ESRGAN
esrgan = ESRGAN(esrgan_bg_tile)
logger.info("ESRGAN Initialized")
return esrgan

View File

@ -1,120 +0,0 @@
import os
import sys
import warnings
import numpy as np
import torch
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
pretrained_model_url = (
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
)
class CodeFormerRestoration:
def __init__(
self, codeformer_dir="./models/core/face_restoration/codeformer", codeformer_model_path="codeformer.pth"
) -> None:
self.globals = InvokeAIAppConfig.get_config()
codeformer_dir = self.globals.root_dir / codeformer_dir
self.model_path = codeformer_dir / codeformer_model_path
self.codeformer_model_exists = self.model_path.exists()
if not self.codeformer_model_exists:
logger.error(f"NOT FOUND: CodeFormer model not found at {self.model_path}")
sys.path.append(os.path.abspath(codeformer_dir))
def process(self, image, strength, device, seed=None, fidelity=0.75):
if seed is not None:
logger.info(f"CodeFormer - Restoring Faces for image seed:{seed}")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
from basicsr.utils import img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from PIL import Image
from torchvision.transforms.functional import normalize
from .codeformer_arch import CodeFormer
cf_class = CodeFormer
cf = cf_class(
dim_embd=512,
codebook_size=1024,
n_head=8,
n_layers=9,
connect_list=["32", "64", "128", "256"],
).to(device)
# note that this file should already be downloaded and cached at
# this point
checkpoint_path = load_file_from_url(
url=pretrained_model_url,
model_dir=os.path.abspath(os.path.dirname(self.model_path)),
progress=True,
)
checkpoint = torch.load(checkpoint_path)["params_ema"]
cf.load_state_dict(checkpoint)
cf.eval()
image = image.convert("RGB")
# Codeformer expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]
face_helper = FaceRestoreHelper(
upscale_factor=1,
use_parse=True,
device=device,
model_rootpath = self.globals.model_path / 'core/face_restoration/gfpgan/weights'
)
face_helper.clean_all()
face_helper.read_image(bgr_image_array)
face_helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5)
face_helper.align_warp_face()
for idx, cropped_face in enumerate(face_helper.cropped_faces):
cropped_face_t = img2tensor(
cropped_face / 255.0, bgr2rgb=True, float32=True
)
normalize(
cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True
)
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
try:
with torch.no_grad():
output = cf(cropped_face_t, w=fidelity, adain=True)[0]
restored_face = tensor2img(
output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)
)
del output
torch.cuda.empty_cache()
except RuntimeError as error:
logger.error(f"Failed inference for CodeFormer: {error}.")
restored_face = cropped_face
restored_face = restored_face.astype("uint8")
face_helper.add_restored_face(restored_face)
face_helper.get_inverse_affine(None)
restored_img = face_helper.paste_faces_to_input_image()
# Flip the channels back to RGB
res = Image.fromarray(restored_img[..., ::-1])
if strength < 1.0:
# Resize the image to the new image if the sizes have changed
if restored_img.size != image.size:
image = image.resize(res.size)
res = Image.blend(image, res, strength)
cf = None
return res

View File

@ -1,325 +0,0 @@
import math
from typing import List, Optional
import numpy as np
import torch
import torch.nn.functional as F
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
from torch import Tensor, nn
from .vqgan_arch import *
def calc_mean_std(feat, eps=1e-5):
"""Calculate mean and std for adaptive_instance_normalization.
Args:
feat (Tensor): 4D tensor.
eps (float): A small value added to the variance to avoid
divide-by-zero. Default: 1e-5.
"""
size = feat.size()
assert len(size) == 4, "The input feature should be 4D tensor."
b, c = size[:2]
feat_var = feat.view(b, c, -1).var(dim=2) + eps
feat_std = feat_var.sqrt().view(b, c, 1, 1)
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
return feat_mean, feat_std
def adaptive_instance_normalization(content_feat, style_feat):
"""Adaptive instance normalization.
Adjust the reference features to have the similar color and illuminations
as those in the degradate features.
Args:
content_feat (Tensor): The reference feature.
style_feat (Tensor): The degradate features.
"""
size = content_feat.size()
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(
size
)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, x, mask=None):
if mask is None:
mask = torch.zeros(
(x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
)
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
class TransformerSALayer(nn.Module):
def __init__(
self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"
):
super().__init__()
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
# Implementation of Feedforward model - MLP
self.linear1 = nn.Linear(embed_dim, dim_mlp)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_mlp, embed_dim)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward(
self,
tgt,
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
# self attention
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(
q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
)[0]
tgt = tgt + self.dropout1(tgt2)
# ffn
tgt2 = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout2(tgt2)
return tgt
class Fuse_sft_block(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.encode_enc = ResBlock(2 * in_ch, out_ch)
self.scale = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
)
self.shift = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
)
def forward(self, enc_feat, dec_feat, w=1):
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
scale = self.scale(enc_feat)
shift = self.shift(enc_feat)
residual = w * (dec_feat * scale + shift)
out = dec_feat + residual
return out
@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
def __init__(
self,
dim_embd=512,
n_head=8,
n_layers=9,
codebook_size=1024,
latent_size=256,
connect_list=["32", "64", "128", "256"],
fix_modules=["quantize", "generator"],
):
super(CodeFormer, self).__init__(
512, 64, [1, 2, 2, 4, 4, 8], "nearest", 2, [16], codebook_size
)
if fix_modules is not None:
for module in fix_modules:
for param in getattr(self, module).parameters():
param.requires_grad = False
self.connect_list = connect_list
self.n_layers = n_layers
self.dim_embd = dim_embd
self.dim_mlp = dim_embd * 2
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
self.feat_emb = nn.Linear(256, self.dim_embd)
# transformer
self.ft_layers = nn.Sequential(
*[
TransformerSALayer(
embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0
)
for _ in range(self.n_layers)
]
)
# logits_predict head
self.idx_pred_layer = nn.Sequential(
nn.LayerNorm(dim_embd), nn.Linear(dim_embd, codebook_size, bias=False)
)
self.channels = {
"16": 512,
"32": 256,
"64": 256,
"128": 128,
"256": 128,
"512": 64,
}
# after second residual block for > 16, before attn layer for ==16
self.fuse_encoder_block = {
"512": 2,
"256": 5,
"128": 8,
"64": 11,
"32": 14,
"16": 18,
}
# after first residual block for > 16, before attn layer for ==16
self.fuse_generator_block = {
"16": 6,
"32": 9,
"64": 12,
"128": 15,
"256": 18,
"512": 21,
}
# fuse_convs_dict
self.fuse_convs_dict = nn.ModuleDict()
for f_size in self.connect_list:
in_ch = self.channels[f_size]
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
# ################### Encoder #####################
enc_feat_dict = {}
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.encoder.blocks):
x = block(x)
if i in out_list:
enc_feat_dict[str(x.shape[-1])] = x.clone()
lq_feat = x
# ################# Transformer ###################
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1)
# BCHW -> BC(HW) -> (HW)BC
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
query_emb = feat_emb
# Transformer encoder
for layer in self.ft_layers:
query_emb = layer(query_emb, query_pos=pos_emb)
# output logits
logits = self.idx_pred_layer(query_emb) # (hw)bn
logits = logits.permute(1, 0, 2) # (hw)bn -> b(hw)n
if code_only: # for training stage II
# logits doesn't need softmax before cross_entropy loss
return logits, lq_feat
# ################# Quantization ###################
# if self.training:
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
# # b(hw)c -> bc(hw) -> bchw
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
# ------------
soft_one_hot = F.softmax(logits, dim=2)
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
quant_feat = self.quantize.get_codebook_feat(
top_idx, shape=[x.shape[0], 16, 16, 256]
)
# preserve gradients
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
if detach_16:
quant_feat = quant_feat.detach() # for training stage III
if adain:
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
# ################## Generator ####################
x = quant_feat
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.generator.blocks):
x = block(x)
if i in fuse_list: # fuse after i-th block
f_size = str(x.shape[-1])
if w > 0:
x = self.fuse_convs_dict[f_size](
enc_feat_dict[f_size].detach(), x, w
)
out = x
# logits doesn't need softmax before cross_entropy loss
return out, logits, lq_feat

View File

@ -1,84 +0,0 @@
import os
import sys
import warnings
import numpy as np
import torch
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
class GFPGAN:
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
self.globals = InvokeAIAppConfig.get_config()
if not os.path.isabs(gfpgan_model_path):
gfpgan_model_path = self.globals.root_dir / gfpgan_model_path
self.model_path = gfpgan_model_path
self.gfpgan_model_exists = os.path.isfile(self.model_path)
if not self.gfpgan_model_exists:
logger.error(f"NOT FOUND: GFPGAN model not found at {self.model_path}")
return None
def model_exists(self):
return os.path.isfile(self.model_path)
def process(self, image, strength: float, seed: str = None):
if seed is not None:
logger.info(f"GFPGAN - Restoring Faces for image seed:{seed}")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
cwd = os.getcwd()
os.chdir(self.globals.root_dir / 'models')
try:
from gfpgan import GFPGANer
self.gfpgan = GFPGANer(
model_path=self.model_path,
upscale=1,
arch="clean",
channel_multiplier=2,
bg_upsampler=None,
)
except Exception:
import traceback
logger.error("Error loading GFPGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
os.chdir(cwd)
if self.gfpgan is None:
logger.warning("WARNING: GFPGAN not initialized.")
logger.warning(
f"Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
)
image = image.convert("RGB")
# GFPGAN expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]
_, _, restored_img = self.gfpgan.enhance(
bgr_image_array,
has_aligned=False,
only_center_face=False,
paste_back=True,
)
# Flip the channels back to RGB
res = Image.fromarray(restored_img[..., ::-1])
if strength < 1.0:
# Resize the image to the new image if the sizes have changed
if restored_img.size != image.size:
image = image.resize(res.size)
res = Image.blend(image, res, strength)
if torch.cuda.is_available():
torch.cuda.empty_cache()
self.gfpgan = None
return res

View File

@ -1,118 +0,0 @@
import math
from PIL import Image
import invokeai.backend.util.logging as logger
class Outcrop(object):
def __init__(
self,
image,
generate, # current generate object
):
self.image = image
self.generate = generate
def process(
self,
extents: dict,
opt, # current options
orig_opt, # ones originally used to generate the image
image_callback=None,
prefix=None,
):
# grow and mask the image
extended_image = self._extend_all(extents)
# switch samplers temporarily
curr_sampler = self.generate.sampler
self.generate.sampler_name = opt.sampler_name
self.generate._set_scheduler()
def wrapped_callback(img, seed, **kwargs):
preferred_seed = (
orig_opt.seed
if orig_opt.seed is not None and orig_opt.seed >= 0
else seed
)
image_callback(img, preferred_seed, use_prefix=prefix, **kwargs)
result = self.generate.prompt2image(
opt.prompt,
seed=opt.seed or orig_opt.seed,
sampler=self.generate.sampler,
steps=opt.steps,
cfg_scale=opt.cfg_scale,
ddim_eta=self.generate.ddim_eta,
width=extended_image.width,
height=extended_image.height,
init_img=extended_image,
strength=0.90,
image_callback=wrapped_callback if image_callback else None,
seam_size=opt.seam_size or 96,
seam_blur=opt.seam_blur or 16,
seam_strength=opt.seam_strength or 0.7,
seam_steps=20,
tile_size=32,
color_match=True,
force_outpaint=True, # this just stops the warning about erased regions
)
# swap sampler back
self.generate.sampler = curr_sampler
return result
def _extend_all(
self,
extents: dict,
) -> Image:
"""
Extend the image in direction ('top','bottom','left','right') by
the indicated value. The image canvas is extended, and the empty
rectangular section will be filled with a blurred copy of the
adjacent image.
"""
image = self.image
for direction in extents:
assert direction in [
"top",
"left",
"bottom",
"right",
], 'Direction must be one of "top", "left", "bottom", "right"'
pixels = extents[direction]
# round pixels up to the nearest 64
pixels = math.ceil(pixels / 64) * 64
logger.info(f"extending image {direction}ward by {pixels} pixels")
image = self._rotate(image, direction)
image = self._extend(image, pixels)
image = self._rotate(image, direction, reverse=True)
return image
def _rotate(self, image: Image, direction: str, reverse=False) -> Image:
"""
Rotates image so that the area to extend is always at the top top.
Simplifies logic later. The reverse argument, if true, will undo the
previous transpose.
"""
transposes = {
"right": ["ROTATE_90", "ROTATE_270"],
"bottom": ["ROTATE_180", "ROTATE_180"],
"left": ["ROTATE_270", "ROTATE_90"],
}
if direction not in transposes:
return image
transpose = transposes[direction][1 if reverse else 0]
return image.transpose(Image.Transpose.__dict__[transpose])
def _extend(self, image: Image, pixels: int) -> Image:
extended_img = Image.new("RGBA", (image.width, image.height + pixels))
extended_img.paste((0, 0, 0), [0, 0, image.width, image.height + pixels])
extended_img.paste(image, box=(0, pixels))
# now make the top part transparent to use as a mask
alpha = extended_img.getchannel("A")
alpha.paste(0, (0, 0, extended_img.width, pixels))
extended_img.putalpha(alpha)
return extended_img

View File

@ -1,102 +0,0 @@
import math
import warnings
from PIL import Image, ImageFilter
class Outpaint(object):
def __init__(self, image, generate):
self.image = image
self.generate = generate
def process(self, opt, old_opt, image_callback=None, prefix=None):
image = self._create_outpaint_image(self.image, opt.out_direction)
seed = old_opt.seed
prompt = old_opt.prompt
def wrapped_callback(img, seed, **kwargs):
image_callback(img, seed, use_prefix=prefix, **kwargs)
return self.generate.prompt2image(
prompt,
seed=seed,
sampler=self.generate.sampler,
steps=opt.steps,
cfg_scale=opt.cfg_scale,
ddim_eta=self.generate.ddim_eta,
width=opt.width,
height=opt.height,
init_img=image,
strength=0.83,
image_callback=wrapped_callback,
prefix=prefix,
)
def _create_outpaint_image(self, image, direction_args):
assert len(direction_args) in [
1,
2,
], "Direction (-D) must have exactly one or two arguments."
if len(direction_args) == 1:
direction = direction_args[0]
pixels = None
elif len(direction_args) == 2:
direction = direction_args[0]
pixels = int(direction_args[1])
assert direction in [
"top",
"left",
"bottom",
"right",
], 'Direction (-D) must be one of "top", "left", "bottom", "right"'
image = image.convert("RGBA")
# we always extend top, but rotate to extend along the requested side
if direction == "left":
image = image.transpose(Image.Transpose.ROTATE_270)
elif direction == "bottom":
image = image.transpose(Image.Transpose.ROTATE_180)
elif direction == "right":
image = image.transpose(Image.Transpose.ROTATE_90)
pixels = image.height // 2 if pixels is None else int(pixels)
assert (
0 < pixels < image.height
), "Direction (-D) pixels length must be in the range 0 - image.size"
# the top part of the image is taken from the source image mirrored
# coordinates (0,0) are the upper left corner of an image
top = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).convert("RGBA")
top = top.crop((0, top.height - pixels, top.width, top.height))
# setting all alpha of the top part to 0
alpha = top.getchannel("A")
alpha.paste(0, (0, 0, top.width, top.height))
top.putalpha(alpha)
# taking the bottom from the original image
bottom = image.crop((0, 0, image.width, image.height - pixels))
new_img = image.copy()
new_img.paste(top, (0, 0))
new_img.paste(bottom, (0, pixels))
# create a 10% dither in the middle
dither = min(image.height // 10, pixels)
for x in range(0, image.width, 2):
for y in range(pixels - dither, pixels + dither):
(r, g, b, a) = new_img.getpixel((x, y))
new_img.putpixel((x, y), (r, g, b, 0))
# let's rotate back again
if direction == "left":
new_img = new_img.transpose(Image.Transpose.ROTATE_90)
elif direction == "bottom":
new_img = new_img.transpose(Image.Transpose.ROTATE_180)
elif direction == "right":
new_img = new_img.transpose(Image.Transpose.ROTATE_270)
return new_img

View File

@ -1,104 +0,0 @@
import warnings
import numpy as np
import torch
from PIL import Image
from PIL.Image import Image as ImageType
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig.get_config()
class ESRGAN:
def __init__(self, bg_tile_size=400) -> None:
self.bg_tile_size = bg_tile_size
def load_esrgan_bg_upsampler(self, denoise_str):
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
else:
use_half_precision = True
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
model = SRVGGNetCompact(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_conv=32,
upscale=4,
act_type="prelu",
)
model_path = config.models_path / "core/upscaling/realesrgan/realesr-general-x4v3.pth"
wdn_model_path = config.models_path / "core/upscaling/realesrgan/realesr-general-wdn-x4v3.pth"
scale = 4
bg_upsampler = RealESRGANer(
scale=scale,
model_path=[model_path, wdn_model_path],
model=model,
tile=self.bg_tile_size,
dni_weight=[denoise_str, 1 - denoise_str],
tile_pad=10,
pre_pad=0,
half=use_half_precision,
)
return bg_upsampler
def process(
self,
image: ImageType,
strength: float,
seed: str = None,
upsampler_scale: int = 2,
denoise_str: float = 0.75,
):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
try:
upsampler = self.load_esrgan_bg_upsampler(denoise_str)
except Exception:
import sys
import traceback
logger.error("Error loading Real-ESRGAN:")
print(traceback.format_exc(), file=sys.stderr)
if upsampler_scale == 0:
logger.warning("Real-ESRGAN: Invalid scaling option. Image not upscaled.")
return image
if seed is not None:
logger.info(
f"Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
)
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
image = image.convert("RGB")
# REALSRGAN expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]
output, _ = upsampler.enhance(
bgr_image_array,
outscale=upsampler_scale,
alpha_upsampler="realesrgan",
)
# Flip the channels back to RGB
res = Image.fromarray(output[..., ::-1])
if strength < 1.0:
# Resize the image to the new image if the sizes have changed
if output.size != image.size:
image = image.resize(res.size)
res = Image.blend(image, res, strength)
if torch.cuda.is_available():
torch.cuda.empty_cache()
upsampler = None
return res

View File

@ -1,514 +0,0 @@
"""
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
"""
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
def normalize(in_channels):
return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
@torch.jit.script
def swish(x):
return x * torch.sigmoid(x)
# Define VQVAE classes
class VectorQuantizer(nn.Module):
def __init__(self, codebook_size, emb_dim, beta):
super(VectorQuantizer, self).__init__()
self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
self.embedding.weight.data.uniform_(
-1.0 / self.codebook_size, 1.0 / self.codebook_size
)
def forward(self, z):
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.emb_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (
(z_flattened**2).sum(dim=1, keepdim=True)
+ (self.embedding.weight**2).sum(1)
- 2 * torch.matmul(z_flattened, self.embedding.weight.t())
)
mean_distance = torch.mean(d)
# find closest encodings
# min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
min_encoding_scores, min_encoding_indices = torch.topk(
d, 1, dim=1, largest=False
)
# [0-1], higher score, higher confidence
min_encoding_scores = torch.exp(-min_encoding_scores / 10)
min_encodings = torch.zeros(
min_encoding_indices.shape[0], self.codebook_size
).to(z)
min_encodings.scatter_(1, min_encoding_indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
# compute loss for embedding
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
(z_q - z.detach()) ** 2
)
# preserve gradients
z_q = z + (z_q - z).detach()
# perplexity
e_mean = torch.mean(min_encodings, dim=0)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return (
z_q,
loss,
{
"perplexity": perplexity,
"min_encodings": min_encodings,
"min_encoding_indices": min_encoding_indices,
"min_encoding_scores": min_encoding_scores,
"mean_distance": mean_distance,
},
)
def get_codebook_feat(self, indices, shape):
# input indices: batch*token_num -> (batch*token_num)*1
# shape: batch, height, width, channel
indices = indices.view(-1, 1)
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
min_encodings.scatter_(1, indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
if shape is not None: # reshape back to match original input shape
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
return z_q
class GumbelQuantizer(nn.Module):
def __init__(
self,
codebook_size,
emb_dim,
num_hiddens,
straight_through=False,
kl_weight=5e-4,
temp_init=1.0,
):
super().__init__()
self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding
self.straight_through = straight_through
self.temperature = temp_init
self.kl_weight = kl_weight
self.proj = nn.Conv2d(
num_hiddens, codebook_size, 1
) # projects last encoder layer to quantized logits
self.embed = nn.Embedding(codebook_size, emb_dim)
def forward(self, z):
hard = self.straight_through if self.training else True
logits = self.proj(z)
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
# + kl divergence to the prior loss
qy = F.softmax(logits, dim=1)
diff = (
self.kl_weight
* torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
)
min_encoding_indices = soft_one_hot.argmax(dim=1)
return z_q, diff, {"min_encoding_indices": min_encoding_indices}
class Downsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def forward(self, x):
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x):
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class ResBlock(nn.Module):
def __init__(self, in_channels, out_channels=None):
super(ResBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.norm1 = normalize(in_channels)
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
self.norm2 = normalize(out_channels)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels:
self.conv_out = nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x_in):
x = x_in
x = self.norm1(x)
x = swish(x)
x = self.conv1(x)
x = self.norm2(x)
x = swish(x)
x = self.conv2(x)
if self.in_channels != self.out_channels:
x_in = self.conv_out(x_in)
return x + x_in
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1)
k = k.reshape(b, c, h * w)
w_ = torch.bmm(q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = F.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return x + h_
class Encoder(nn.Module):
def __init__(
self,
in_channels,
nf,
emb_dim,
ch_mult,
num_res_blocks,
resolution,
attn_resolutions,
):
super().__init__()
self.nf = nf
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.attn_resolutions = attn_resolutions
curr_res = self.resolution
in_ch_mult = (1,) + tuple(ch_mult)
blocks = []
# initial convultion
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
# residual and downsampling blocks, with attention on smaller res (16x16)
for i in range(self.num_resolutions):
block_in_ch = nf * in_ch_mult[i]
block_out_ch = nf * ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
if i != self.num_resolutions - 1:
blocks.append(Downsample(block_in_ch))
curr_res = curr_res // 2
# non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch))
blocks.append(AttnBlock(block_in_ch))
blocks.append(ResBlock(block_in_ch, block_in_ch))
# normalise and convert to latent size
blocks.append(normalize(block_in_ch))
blocks.append(
nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)
)
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
class Generator(nn.Module):
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
super().__init__()
self.nf = nf
self.ch_mult = ch_mult
self.num_resolutions = len(self.ch_mult)
self.num_res_blocks = res_blocks
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.in_channels = emb_dim
self.out_channels = 3
block_in_ch = self.nf * self.ch_mult[-1]
curr_res = self.resolution // 2 ** (self.num_resolutions - 1)
blocks = []
# initial conv
blocks.append(
nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)
)
# non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch))
blocks.append(AttnBlock(block_in_ch))
blocks.append(ResBlock(block_in_ch, block_in_ch))
for i in reversed(range(self.num_resolutions)):
block_out_ch = self.nf * self.ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in self.attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
if i != 0:
blocks.append(Upsample(block_in_ch))
curr_res = curr_res * 2
blocks.append(normalize(block_in_ch))
blocks.append(
nn.Conv2d(
block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1
)
)
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
def __init__(
self,
img_size,
nf,
ch_mult,
quantizer="nearest",
res_blocks=2,
attn_resolutions=[16],
codebook_size=1024,
emb_dim=256,
beta=0.25,
gumbel_straight_through=False,
gumbel_kl_weight=1e-8,
model_path=None,
):
super().__init__()
logger = get_root_logger()
self.in_channels = 3
self.nf = nf
self.n_blocks = res_blocks
self.codebook_size = codebook_size
self.embed_dim = emb_dim
self.ch_mult = ch_mult
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.quantizer_type = quantizer
self.encoder = Encoder(
self.in_channels,
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions,
)
if self.quantizer_type == "nearest":
self.beta = beta # 0.25
self.quantize = VectorQuantizer(
self.codebook_size, self.embed_dim, self.beta
)
elif self.quantizer_type == "gumbel":
self.gumbel_num_hiddens = emb_dim
self.straight_through = gumbel_straight_through
self.kl_weight = gumbel_kl_weight
self.quantize = GumbelQuantizer(
self.codebook_size,
self.embed_dim,
self.gumbel_num_hiddens,
self.straight_through,
self.kl_weight,
)
self.generator = Generator(
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions,
)
if model_path is not None:
chkpt = torch.load(model_path, map_location="cpu")
if "params_ema" in chkpt:
self.load_state_dict(
torch.load(model_path, map_location="cpu")["params_ema"]
)
logger.info(f"vqgan is loaded from: {model_path} [params_ema]")
elif "params" in chkpt:
self.load_state_dict(
torch.load(model_path, map_location="cpu")["params"]
)
logger.info(f"vqgan is loaded from: {model_path} [params]")
else:
raise ValueError(f"Wrong params!")
def forward(self, x):
x = self.encoder(x)
quant, codebook_loss, quant_stats = self.quantize(x)
x = self.generator(quant)
return x, codebook_loss, quant_stats
# patch based discriminator
@ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module):
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
super().__init__()
layers = [
nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, True),
]
ndf_mult = 1
ndf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
ndf_mult_prev = ndf_mult
ndf_mult = min(2**n, 8)
layers += [
nn.Conv2d(
ndf * ndf_mult_prev,
ndf * ndf_mult,
kernel_size=4,
stride=2,
padding=1,
bias=False,
),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True),
]
ndf_mult_prev = ndf_mult
ndf_mult = min(2**n_layers, 8)
layers += [
nn.Conv2d(
ndf * ndf_mult_prev,
ndf * ndf_mult,
kernel_size=4,
stride=1,
padding=1,
bias=False,
),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True),
]
layers += [
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)
] # output 1 channel prediction map
self.main = nn.Sequential(*layers)
if model_path is not None:
chkpt = torch.load(model_path, map_location="cpu")
if "params_d" in chkpt:
self.load_state_dict(
torch.load(model_path, map_location="cpu")["params_d"]
)
elif "params" in chkpt:
self.load_state_dict(
torch.load(model_path, map_location="cpu")["params"]
)
else:
raise ValueError(f"Wrong params!")
def forward(self, x):
return self.main(x)

View File

@@ -221,7 +221,7 @@ class ControlNetData:
     control_mode: str = Field(default="balanced")
-@dataclass(frozen=True)
+@dataclass
 class ConditioningData:
     unconditioned_embeddings: torch.Tensor
     text_embeddings: torch.Tensor
@@ -422,7 +422,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         noise: torch.Tensor,
         callback: Callable[[PipelineIntermediateState], None] = None,
         run_id=None,
-        **kwargs,
     ) -> InvokeAIStableDiffusionPipelineOutput:
         r"""
         Function invoked when calling the pipeline for generation.
@@ -443,7 +442,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             noise=noise,
             run_id=run_id,
             callback=callback,
-            **kwargs,
         )
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
@@ -469,7 +467,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         run_id=None,
         callback: Callable[[PipelineIntermediateState], None] = None,
         control_data: List[ControlNetData] = None,
-        **kwargs,
     ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
         if self.scheduler.config.get("cpu_only", False):
             scheduler_device = torch.device('cpu')
@@ -487,11 +484,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             timesteps,
             conditioning_data,
             noise=noise,
-            additional_guidance=additional_guidance,
             run_id=run_id,
-            callback=callback,
+            additional_guidance=additional_guidance,
             control_data=control_data,
-            **kwargs,
+            callback=callback,
         )
         return result.latents, result.attention_map_saver
@@ -505,7 +502,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         run_id: str = None,
         additional_guidance: List[Callable] = None,
         control_data: List[ControlNetData] = None,
-        **kwargs,
     ):
         self._adjust_memory_efficient_attention(latents)
         if run_id is None:
@@ -546,7 +542,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 total_step_count=len(timesteps),
                 additional_guidance=additional_guidance,
                 control_data=control_data,
-                **kwargs,
             )
             latents = step_output.prev_sample
@@ -588,7 +583,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         total_step_count: int,
         additional_guidance: List[Callable] = None,
         control_data: List[ControlNetData] = None,
-        **kwargs,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
@@ -632,9 +626,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 if cfg_injection:  # only applying ControlNet to conditional instead of in unconditioned
                     encoder_hidden_states = conditioning_data.text_embeddings
+                    encoder_attention_mask = None
                 else:
-                    encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
-                                                       conditioning_data.text_embeddings])
+                    encoder_hidden_states, encoder_attention_mask = self.invokeai_diffuser._concat_conditionings_for_batch(
+                        conditioning_data.unconditioned_embeddings,
+                        conditioning_data.text_embeddings,
+                    )
                 if isinstance(control_datum.weight, list):
                     # if controlnet has multiple weights, use the weight for the current step
                     controlnet_weight = control_datum.weight[step_index]
@@ -649,6 +646,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                     encoder_hidden_states=encoder_hidden_states,
                     controlnet_cond=control_datum.image_tensor,
                     conditioning_scale=controlnet_weight,  # controlnet specific, NOT the guidance scale
+                    encoder_attention_mask=encoder_attention_mask,
                     guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
                     return_dict=False,
                 )
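
For context on the `_concat_conditionings_for_batch` call used above: when the unconditioned and conditioned prompt embeddings have different token lengths (e.g. a long positive prompt), they must be padded to a common length before being batched, and an attention mask records which positions are real. The sketch below is a simplified, self-contained illustration of that idea using zero padding and made-up shapes; it is not the project's exact padding code.

```python
import torch

def pad_to_len(cond: torch.Tensor, target_len: int):
    # cond: (batch, tokens, dim) -> padded tensor plus a mask marking real tokens.
    b, t, d = cond.shape
    mask = torch.ones(b, target_len, dtype=cond.dtype, device=cond.device)
    if t < target_len:
        pad = torch.zeros(b, target_len - t, d, dtype=cond.dtype, device=cond.device)
        cond = torch.cat([cond, pad], dim=1)
        mask[:, t:] = 0          # padded positions are masked out
    return cond, mask

uncond = torch.randn(1, 77, 768)   # negative prompt, one CLIP chunk
cond = torch.randn(1, 154, 768)    # long positive prompt, two chunks
max_len = max(uncond.shape[1], cond.shape[1])

uncond, uncond_mask = pad_to_len(uncond, max_len)
cond, cond_mask = pad_to_len(cond, max_len)

encoder_hidden_states = torch.cat([uncond, cond])             # batched [uncond, cond]
encoder_attention_mask = torch.cat([uncond_mask, cond_mask])
print(encoder_hidden_states.shape, encoder_attention_mask.shape)  # (2, 154, 768) (2, 154)
```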

View File

@ -237,11 +237,7 @@ class InvokeAIDiffuserComponent:
) )
return latents return latents
# methods below are called from do_diffusion_step and should be considered private to this class. def _concat_conditionings_for_batch(self, unconditioning, conditioning):
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
# fast batched path
def _pad_conditioning(cond, target_len, encoder_attention_mask): def _pad_conditioning(cond, target_len, encoder_attention_mask):
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype) conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
@ -266,16 +262,24 @@ class InvokeAIDiffuserComponent:
return cond, encoder_attention_mask return cond, encoder_attention_mask
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
encoder_attention_mask = None encoder_attention_mask = None
if unconditioning.shape[1] != conditioning.shape[1]: if unconditioning.shape[1] != conditioning.shape[1]:
max_len = max(unconditioning.shape[1], conditioning.shape[1]) max_len = max(unconditioning.shape[1], conditioning.shape[1])
unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask) unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask) conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
both_conditionings = torch.cat([unconditioning, conditioning]) return torch.cat([unconditioning, conditioning]), encoder_attention_mask
# methods below are called from do_diffusion_step and should be considered private to this class.
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
# fast batched path
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
unconditioning, conditioning
)
both_results = self.model_forward_callback( both_results = self.model_forward_callback(
x_twice, sigma_twice, both_conditionings, x_twice, sigma_twice, both_conditionings,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
@ -293,8 +297,32 @@ class InvokeAIDiffuserComponent:
**kwargs, **kwargs,
): ):
# low-memory sequential path # low-memory sequential path
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs) uncond_down_block, cond_down_block = None, None
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs) down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
if down_block_additional_residuals is not None:
uncond_down_block, cond_down_block = [], []
for down_block in down_block_additional_residuals:
_uncond_down, _cond_down = down_block.chunk(2)
uncond_down_block.append(_uncond_down)
cond_down_block.append(_cond_down)
uncond_mid_block, cond_mid_block = None, None
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
unconditioned_next_x = self.model_forward_callback(
x, sigma, unconditioning,
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
**kwargs,
)
conditioned_next_x = self.model_forward_callback(
x, sigma, conditioning,
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
**kwargs,
)
return unconditioned_next_x, conditioned_next_x return unconditioned_next_x, conditioned_next_x
# TODO: looks unused # TODO: looks unused
@ -328,6 +356,20 @@ class InvokeAIDiffuserComponent:
): ):
context: Context = self.cross_attention_control_context context: Context = self.cross_attention_control_context
uncond_down_block, cond_down_block = None, None
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
if down_block_additional_residuals is not None:
uncond_down_block, cond_down_block = [], []
for down_block in down_block_additional_residuals:
_uncond_down, _cond_down = down_block.chunk(2)
uncond_down_block.append(_uncond_down)
cond_down_block.append(_cond_down)
uncond_mid_block, cond_mid_block = None, None
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
cross_attn_processor_context = SwapCrossAttnContext( cross_attn_processor_context = SwapCrossAttnContext(
modified_text_embeddings=context.arguments.edited_conditioning, modified_text_embeddings=context.arguments.edited_conditioning,
index_map=context.cross_attention_index_map, index_map=context.cross_attention_index_map,
@ -340,6 +382,8 @@ class InvokeAIDiffuserComponent:
sigma, sigma,
unconditioning, unconditioning,
{"swap_cross_attn_context": cross_attn_processor_context}, {"swap_cross_attn_context": cross_attn_processor_context},
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
**kwargs,
)
@ -352,6 +396,8 @@ class InvokeAIDiffuserComponent:
sigma,
conditioning,
{"swap_cross_attn_context": cross_attn_processor_context},
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
**kwargs,
)
return unconditioned_next_x, conditioned_next_x
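
For reference, a minimal sketch of the batching pattern these hunks add: ControlNet residuals arrive with the unconditioned and conditioned halves stacked along the batch dimension, and the sequential (low-memory) path splits them with `chunk(2)` before making the two separate UNet calls. The tensor shapes below are illustrative assumptions, not values taken from the diff.

```python
import torch

# Batched residuals: unconditioned and conditioned halves stacked on dim 0
# (shapes are made up for illustration).
down_block_residuals = [torch.randn(2, 320, 64, 64), torch.randn(2, 640, 32, 32)]
mid_block_residual = torch.randn(2, 1280, 8, 8)

# Split each residual into its unconditioned / conditioned halves,
# mirroring the chunk(2) calls in the hunks above.
uncond_down, cond_down = [], []
for residual in down_block_residuals:
    uncond_half, cond_half = residual.chunk(2)
    uncond_down.append(uncond_half)
    cond_down.append(cond_half)
uncond_mid, cond_mid = mid_block_residual.chunk(2)

# Each half is then routed to its own model_forward_callback via
# down_block_additional_residuals / mid_block_additional_residual.
print([t.shape for t in uncond_down], uncond_mid.shape)
```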

View File

@ -0,0 +1,634 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import (
CrossAttnDownBlock2D,
DownBlock2D,
UNetMidBlock2DCrossAttn,
get_down_block,
)
from diffusers.models.unet_2d_condition import UNet2DConditionModel
import diffusers
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
# Modified ControlNetModel with encoder_attention_mask argument added
class ControlNetModel(ModelMixin, ConfigMixin):
"""
A ControlNet model.
Args:
in_channels (`int`, defaults to 4):
The number of channels in the input sample.
flip_sin_to_cos (`bool`, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, defaults to 0):
The frequency shift to apply to the time embedding.
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, defaults to 2):
The number of layers per block.
downsample_padding (`int`, defaults to 1):
The padding to use for the downsampling convolution.
mid_block_scale_factor (`float`, defaults to 1):
The scale factor to use for the mid block.
act_fn (`str`, defaults to "silu"):
The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups to use for the normalization. If None, normalization and activation layers are skipped
in post-processing.
norm_eps (`float`, defaults to 1e-5):
The epsilon to use for the normalization.
cross_attention_dim (`int`, defaults to 1280):
The dimension of the cross attention features.
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
The dimension of the attention heads.
use_linear_projection (`bool`, defaults to `False`):
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
num_class_embeds (`int`, *optional*, defaults to 0):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
upcast_attention (`bool`, defaults to `False`):
resnet_time_scale_shift (`str`, defaults to `"default"`):
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
`class_embed_type="projection"`.
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
The tuple of output channel for each block in the `conditioning_embedding` layer.
global_pool_conditions (`bool`, defaults to `False`):
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 4,
conditioning_channels: int = 3,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
attention_head_dim: Union[int, Tuple[int]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
projection_class_embeddings_input_dim: Optional[int] = None,
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
global_pool_conditions: bool = False,
):
super().__init__()
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
# Check inputs
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
# input
conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
)
# time
time_embed_dim = block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
act_fn=act_fn,
)
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
elif class_embed_type == "projection":
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
)
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
# 2. it projects from an arbitrary input dimension.
#
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
else:
self.class_embedding = None
# control net conditioning embedding
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
conditioning_embedding_channels=block_out_channels[0],
block_out_channels=conditioning_embedding_out_channels,
conditioning_channels=conditioning_channels,
)
self.down_blocks = nn.ModuleList([])
self.controlnet_down_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[i],
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
self.down_blocks.append(down_block)
for _ in range(layers_per_block):
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
if not is_final_block:
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
# mid
mid_block_channel = block_out_channels[-1]
controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_mid_block = controlnet_block
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=mid_block_channel,
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
@classmethod
def from_unet(
cls,
unet: UNet2DConditionModel,
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
load_weights_from_unet: bool = True,
):
r"""
Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
Parameters:
unet (`UNet2DConditionModel`):
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
where applicable.
"""
controlnet = cls(
in_channels=unet.config.in_channels,
flip_sin_to_cos=unet.config.flip_sin_to_cos,
freq_shift=unet.config.freq_shift,
down_block_types=unet.config.down_block_types,
only_cross_attention=unet.config.only_cross_attention,
block_out_channels=unet.config.block_out_channels,
layers_per_block=unet.config.layers_per_block,
downsample_padding=unet.config.downsample_padding,
mid_block_scale_factor=unet.config.mid_block_scale_factor,
act_fn=unet.config.act_fn,
norm_num_groups=unet.config.norm_num_groups,
norm_eps=unet.config.norm_eps,
cross_attention_dim=unet.config.cross_attention_dim,
attention_head_dim=unet.config.attention_head_dim,
num_attention_heads=unet.config.num_attention_heads,
use_linear_projection=unet.config.use_linear_projection,
class_embed_type=unet.config.class_embed_type,
num_class_embeds=unet.config.num_class_embeds,
upcast_attention=unet.config.upcast_attention,
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
)
if load_weights_from_unet:
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
if controlnet.class_embedding:
controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
return controlnet
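# Usage sketch (comment only, not part of the diff): building the ControlNet from an
# existing UNet so its matching weights are copied over. The model id below is just an
# illustrative example.
#
#     unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
#     controlnet = ControlNetModel.from_unet(unet, load_weights_from_unet=True)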
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "set_processor"):
processors[f"{name}.processor"] = module.processor
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
self.set_attn_processor(AttnProcessor())
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
sliceable_head_dims.append(module.sliceable_head_dim)
for child in module.children():
fn_recursive_retrieve_sliceable_dims(child)
# retrieve number of attention layers
for module in self.children():
fn_recursive_retrieve_sliceable_dims(module)
num_sliceable_layers = len(sliceable_head_dims)
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
# make smallest slice possible
slice_size = num_sliceable_layers * [1]
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
)
for i in range(len(slice_size)):
size = slice_size[i]
dim = sliceable_head_dims[i]
if size is not None and size > dim:
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
# Recursively walk through all the children.
# Any children which exposes the set_attention_slice method
# gets the message
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
if hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size.pop())
for child in module.children():
fn_recursive_set_attention_slice(child, slice_size)
reversed_slice_size = list(reversed(slice_size))
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
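# Usage sketch (comment only, not part of the diff): trading a little speed for lower
# peak memory by slicing attention, per the docstring above.
#
#     controlnet.set_attention_slice("auto")  # halve each attention head input
#     controlnet.set_attention_slice("max")   # one slice at a time, lowest memory
#     controlnet.set_attention_slice(2)       # fixed slice size; attention_head_dim must be a multiple of it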
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
controlnet_cond: torch.FloatTensor,
conditioning_scale: float = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple]:
"""
The [`ControlNetModel`] forward method.
Args:
sample (`torch.FloatTensor`):
The noisy input tensor.
timestep (`Union[torch.Tensor, float, int]`):
The number of timesteps to denoise an input.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states.
controlnet_cond (`torch.FloatTensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
encoder_attention_mask (`torch.Tensor`):
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
which adds large negative values to the attention scores corresponding to "discard" tokens.
guess_mode (`bool`, defaults to `False`):
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
Returns:
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
returned where the first element is the sample tensor.
"""
# check channel order
channel_order = self.config.controlnet_conditioning_channel_order
if channel_order == "rgb":
# in rgb order by default
...
elif channel_order == "bgr":
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
else:
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None:
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
# 2. pre-process
sample = self.conv_in(sample)
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
sample = sample + controlnet_cond
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
if self.mid_block is not None:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)
# 5. Control net blocks
controlnet_down_block_res_samples = ()
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
down_block_res_sample = controlnet_block(down_block_res_sample)
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
down_block_res_samples = controlnet_down_block_res_samples
mid_block_res_sample = self.controlnet_mid_block(sample)
# 6. scaling
if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
scales = scales * conditioning_scale
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
else:
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * conditioning_scale
if self.config.global_pool_conditions:
down_block_res_samples = [
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
]
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
if not return_dict:
return (down_block_res_samples, mid_block_res_sample)
return ControlNetOutput(
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
)
diffusers.ControlNetModel = ControlNetModel
diffusers.models.controlnet.ControlNetModel = ControlNetModel
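
A rough usage sketch of the patched forward pass, showing the `encoder_attention_mask` argument that this copy adds relative to the upstream diffusers ControlNet. It assumes the `ControlNetModel` class defined above; the shapes and dummy tensors are assumptions for illustration only.

```python
import torch

# Dummy SD-1.5-style inputs (shapes assumed for illustration).
latents = torch.randn(2, 4, 64, 64)
timestep = torch.tensor(999)
text_embeds = torch.randn(2, 77, 768)
text_mask = torch.ones(2, 77, dtype=torch.long)  # 1 = keep token, 0 = discard
control_image = torch.randn(2, 3, 512, 512)

controlnet = ControlNetModel(cross_attention_dim=768)

down_res, mid_res = controlnet(
    sample=latents,
    timestep=timestep,
    encoder_hidden_states=text_embeds,
    controlnet_cond=control_image,
    conditioning_scale=1.0,
    encoder_attention_mask=text_mask,  # the argument this patched copy threads through
    return_dict=False,
)
print(len(down_res), mid_res.shape)
```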

View File

@ -64,22 +64,29 @@ sd-1/main/waifu-diffusion:
  recommended: False
sd-1/controlnet/canny:
  repo_id: lllyasviel/control_v11p_sd15_canny
  recommended: True
sd-1/controlnet/inpaint:
  repo_id: lllyasviel/control_v11p_sd15_inpaint
sd-1/controlnet/mlsd:
  repo_id: lllyasviel/control_v11p_sd15_mlsd
sd-1/controlnet/depth:
  repo_id: lllyasviel/control_v11f1p_sd15_depth
  recommended: True
sd-1/controlnet/normal_bae:
  repo_id: lllyasviel/control_v11p_sd15_normalbae
sd-1/controlnet/seg:
  repo_id: lllyasviel/control_v11p_sd15_seg
sd-1/controlnet/lineart:
  repo_id: lllyasviel/control_v11p_sd15_lineart
  recommended: True
sd-1/controlnet/lineart_anime:
  repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime
sd-1/controlnet/openpose:
  repo_id: lllyasviel/control_v11p_sd15_openpose
  recommended: True
sd-1/controlnet/scribble:
  repo_id: lllyasviel/control_v11p_sd15_scribble
  recommended: False
sd-1/controlnet/softedge:
  repo_id: lllyasviel/control_v11p_sd15_softedge
sd-1/controlnet/shuffle:
@ -90,6 +97,7 @@ sd-1/controlnet/ip2p:
  repo_id: lllyasviel/control_v11e_sd15_ip2p
sd-1/embedding/EasyNegative:
  path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
  recommended: True
sd-1/embedding/ahx-beta-453407d:
  repo_id: sd-concepts-library/ahx-beta-453407d
sd-1/lora/LowRA:

View File

@ -256,6 +256,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
widgets = dict()
model_list = [x for x in self.all_models if self.all_models[x].model_type==model_type and not x in exclude]
model_labels = [self.model_labels[x] for x in model_list]
show_recommended = len(self.installed_models)==0
if len(model_list) > 0:
max_width = max([len(x) for x in model_labels])
columns = window_width // (max_width+8) # 8 characters for "[x] " and padding
@ -280,7 +282,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
value=[
model_list.index(x)
for x in model_list
if self.all_models[x].installed
if (show_recommended and self.all_models[x].recommended) \
or self.all_models[x].installed
],
max_height=len(model_list)//columns + 1,
relx=4,
@ -672,7 +675,9 @@ def select_and_download_models(opt: Namespace):
# pass
installer = ModelInstall(config, prediction_type_helper=helper)
if opt.add or opt.delete:
if opt.list_models:
installer.list_models(opt.list_models)
elif opt.add or opt.delete:
selections = InstallSelections(
install_models = opt.add or [],
remove_models = opt.delete or []
@ -745,7 +750,7 @@ def main():
)
parser.add_argument(
"--list-models",
choices=["diffusers","loras","controlnets","tis"],
choices=[x.value for x in ModelType],
help="list installed models",
)
parser.add_argument(
@ -773,7 +778,7 @@ def main():
config.parse_args(invoke_args)
logger = InvokeAILogger().getLogger(config=config)
if not (config.conf_path / 'models.yaml').exists():
if not config.model_conf_path.exists():
logger.info(
"Your InvokeAI root directory is not set up. Calling invokeai-configure."
)
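
As a small illustration of the `--list-models` change above, the argparse choices can be derived directly from the model-type enum instead of a hard-coded list. The enum members below are placeholders for illustration, not the project's actual `ModelType` definition.

```python
import argparse
from enum import Enum

class ModelType(str, Enum):
    # Hypothetical members; the real enum lives in the model manager.
    Main = "main"
    Lora = "lora"
    ControlNet = "controlnet"
    TextualInversion = "embedding"

parser = argparse.ArgumentParser()
parser.add_argument(
    "--list-models",
    choices=[x.value for x in ModelType],
    help="list installed models",
)
print(parser.parse_args(["--list-models", "controlnet"]))
```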

View File

@ -399,6 +399,8 @@
"deleteModel": "Delete Model", "deleteModel": "Delete Model",
"deleteConfig": "Delete Config", "deleteConfig": "Delete Config",
"deleteMsg1": "Are you sure you want to delete this model from InvokeAI?", "deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
"modelDeleted": "Model Deleted",
"modelDeleteFailed": "Failed to delete model",
"deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.", "deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
"formMessageDiffusersModelLocation": "Diffusers Model Location", "formMessageDiffusersModelLocation": "Diffusers Model Location",
"formMessageDiffusersModelLocationDesc": "Please enter at least one.", "formMessageDiffusersModelLocationDesc": "Please enter at least one.",
@ -408,11 +410,13 @@
"convertToDiffusers": "Convert To Diffusers", "convertToDiffusers": "Convert To Diffusers",
"convertToDiffusersHelpText1": "This model will be converted to the 🧨 Diffusers format.", "convertToDiffusersHelpText1": "This model will be converted to the 🧨 Diffusers format.",
"convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.", "convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.",
"convertToDiffusersHelpText3": "Your checkpoint file on the disk will NOT be deleted or modified in anyway. You can add your checkpoint to the Model Manager again if you want to.", "convertToDiffusersHelpText3": "Your checkpoint file on disk WILL be deleted if it is in InvokeAI root folder. If it is in a custom location, then it WILL NOT be deleted.",
"convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.", "convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.",
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.", "convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.",
"convertToDiffusersHelpText6": "Do you wish to convert this model?", "convertToDiffusersHelpText6": "Do you wish to convert this model?",
"convertToDiffusersSaveLocation": "Save Location", "convertToDiffusersSaveLocation": "Save Location",
"noCustomLocationProvided": "No Custom Location Provided",
"convertingModelBegin": "Converting Model. Please wait.",
"v1": "v1", "v1": "v1",
"v2_base": "v2 (512px)", "v2_base": "v2 (512px)",
"v2_768": "v2 (768px)", "v2_768": "v2 (768px)",
@ -450,7 +454,8 @@
"none": "none", "none": "none",
"addDifference": "Add Difference", "addDifference": "Add Difference",
"pickModelType": "Pick Model Type", "pickModelType": "Pick Model Type",
"selectModel": "Select Model" "selectModel": "Select Model",
"importModels": "Import Models"
}, },
"parameters": { "parameters": {
"general": "General", "general": "General",
@ -572,6 +577,7 @@
"uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image", "uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image",
"downloadImageStarted": "Image Download Started", "downloadImageStarted": "Image Download Started",
"imageCopied": "Image Copied", "imageCopied": "Image Copied",
"problemCopyingImage": "Unable to Copy Image",
"imageLinkCopied": "Image Link Copied", "imageLinkCopied": "Image Link Copied",
"problemCopyingImageLink": "Unable to Copy Image Link", "problemCopyingImageLink": "Unable to Copy Image Link",
"imageNotLoaded": "No Image Loaded", "imageNotLoaded": "No Image Loaded",

View File

@ -88,6 +88,8 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addModelLoadStartedEventListener } from './listeners/socketio/socketModelLoadStarted';
import { addModelLoadCompletedEventListener } from './listeners/socketio/socketModelLoadCompleted';
export const listenerMiddleware = createListenerMiddleware();
@ -177,6 +179,8 @@ addSocketConnectedListener();
addSocketDisconnectedListener();
addSocketSubscribedListener();
addSocketUnsubscribedListener();
addModelLoadStartedEventListener();
addModelLoadCompletedEventListener();
// Session Created
addSessionCreatedPendingListener();

View File

@ -0,0 +1,28 @@
import { log } from 'app/logging/useLogger';
import {
appSocketModelLoadCompleted,
socketModelLoadCompleted,
} from 'services/events/actions';
import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' });
export const addModelLoadCompletedEventListener = () => {
startAppListening({
actionCreator: socketModelLoadCompleted,
effect: (action, { dispatch, getState }) => {
const { model_name, model_type, submodel } = action.payload.data;
let modelString = `${model_type} model: ${model_name}`;
if (submodel) {
modelString = modelString.concat(`, submodel: ${submodel}`);
}
moduleLog.debug(action.payload, `Model load completed (${modelString})`);
// pass along the socket event as an application action
dispatch(appSocketModelLoadCompleted(action.payload));
},
});
};

View File

@ -0,0 +1,28 @@
import { log } from 'app/logging/useLogger';
import {
appSocketModelLoadStarted,
socketModelLoadStarted,
} from 'services/events/actions';
import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' });
export const addModelLoadStartedEventListener = () => {
startAppListening({
actionCreator: socketModelLoadStarted,
effect: (action, { dispatch, getState }) => {
const { model_name, model_type, submodel } = action.payload.data;
let modelString = `${model_type} model: ${model_name}`;
if (submodel) {
modelString = modelString.concat(`, submodel: ${submodel}`);
}
moduleLog.debug(action.payload, `Model load started (${modelString})`);
// pass along the socket event as an application action
dispatch(appSocketModelLoadStarted(action.payload));
},
});
};

View File

@ -21,6 +21,7 @@ import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import configReducer from 'features/system/store/configSlice';
import systemReducer from 'features/system/store/systemSlice';
import modelmanagerReducer from 'features/ui/components/tabs/ModelManager/store/modelManagerSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import uiReducer from 'features/ui/store/uiSlice';
@ -49,6 +50,7 @@ const allReducers = {
dynamicPrompts: dynamicPromptsReducer,
imageDeletion: imageDeletionReducer,
lora: loraReducer,
modelmanager: modelmanagerReducer,
[api.reducerPath]: api.reducer,
};
@ -67,6 +69,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'controlNet',
'dynamicPrompts',
'lora',
'modelmanager',
];
export const store = configureStore({

View File

@ -21,6 +21,7 @@ import { ImageDTO } from 'services/api/types';
import { mode } from 'theme/util/mode';
import IAIDraggable from './IAIDraggable';
import IAIDroppable from './IAIDroppable';
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
type IAIDndImageProps = {
imageDTO: ImageDTO | undefined;
@ -96,119 +97,124 @@ const IAIDndImage = (props: IAIDndImageProps) => {
};
return (
<Flex <ImageContextMenu imageDTO={imageDTO}>
sx={{ {(ref) => (
width: 'full',
height: 'full',
alignItems: 'center',
justifyContent: 'center',
position: 'relative',
minW: minSize ? minSize : undefined,
minH: minSize ? minSize : undefined,
userSelect: 'none',
cursor: isDragDisabled || !imageDTO ? 'default' : 'pointer',
}}
>
{imageDTO && (
<Flex <Flex
ref={ref}
sx={{ sx={{
w: 'full', width: 'full',
h: 'full', height: 'full',
position: fitContainer ? 'absolute' : 'relative',
alignItems: 'center', alignItems: 'center',
justifyContent: 'center', justifyContent: 'center',
position: 'relative',
minW: minSize ? minSize : undefined,
minH: minSize ? minSize : undefined,
userSelect: 'none',
cursor: isDragDisabled || !imageDTO ? 'default' : 'pointer',
}} }}
> >
<Image {imageDTO && (
src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url} <Flex
fallbackStrategy="beforeLoadOrError"
// If we fall back to thumbnail, it feels much snappier than the skeleton...
fallbackSrc={imageDTO.thumbnail_url}
// fallback={<IAILoadingImageFallback image={imageDTO} />}
width={imageDTO.width}
height={imageDTO.height}
onError={onError}
draggable={false}
sx={{
objectFit: 'contain',
maxW: 'full',
maxH: 'full',
borderRadius: 'base',
shadow: isSelected ? 'selected.light' : undefined,
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
...imageSx,
}}
/>
{withMetadataOverlay && <ImageMetadataOverlay image={imageDTO} />}
</Flex>
)}
{!imageDTO && !isUploadDisabled && (
<>
<Flex
sx={{
minH: minSize,
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
transitionProperty: 'common',
transitionDuration: '0.1s',
color: mode('base.500', 'base.500')(colorMode),
...uploadButtonStyles,
}}
{...getUploadButtonProps()}
>
<input {...getUploadInputProps()} />
<Icon
as={FaUpload}
sx={{ sx={{
boxSize: 16, w: 'full',
h: 'full',
position: fitContainer ? 'absolute' : 'relative',
alignItems: 'center',
justifyContent: 'center',
}}
>
<Image
src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url}
fallbackStrategy="beforeLoadOrError"
// If we fall back to thumbnail, it feels much snappier than the skeleton...
fallbackSrc={imageDTO.thumbnail_url}
// fallback={<IAILoadingImageFallback image={imageDTO} />}
width={imageDTO.width}
height={imageDTO.height}
onError={onError}
draggable={false}
sx={{
objectFit: 'contain',
maxW: 'full',
maxH: 'full',
borderRadius: 'base',
shadow: isSelected ? 'selected.light' : undefined,
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
...imageSx,
}}
/>
{withMetadataOverlay && <ImageMetadataOverlay image={imageDTO} />}
</Flex>
)}
{!imageDTO && !isUploadDisabled && (
<>
<Flex
sx={{
minH: minSize,
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
transitionProperty: 'common',
transitionDuration: '0.1s',
color: mode('base.500', 'base.500')(colorMode),
...uploadButtonStyles,
}}
{...getUploadButtonProps()}
>
<input {...getUploadInputProps()} />
<Icon
as={FaUpload}
sx={{
boxSize: 16,
}}
/>
</Flex>
</>
)}
{!imageDTO && isUploadDisabled && noContentFallback}
{!isDropDisabled && (
<IAIDroppable
data={droppableData}
disabled={isDropDisabled}
dropLabel={dropLabel}
/>
)}
{imageDTO && !isDragDisabled && (
<IAIDraggable
data={draggableData}
disabled={isDragDisabled || !imageDTO}
onClick={onClick}
/>
)}
{onClickReset && withResetIcon && imageDTO && (
<IAIIconButton
onClick={onClickReset}
aria-label={resetTooltip}
tooltip={resetTooltip}
icon={resetIcon}
size="sm"
variant="link"
sx={{
position: 'absolute',
top: 1,
insetInlineEnd: 1,
p: 0,
minW: 0,
svg: {
transitionProperty: 'common',
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: resetIconShadow,
},
}} }}
/> />
</Flex> )}
</> </Flex>
)} )}
{!imageDTO && isUploadDisabled && noContentFallback} </ImageContextMenu>
{!isDropDisabled && (
<IAIDroppable
data={droppableData}
disabled={isDropDisabled}
dropLabel={dropLabel}
/>
)}
{imageDTO && !isDragDisabled && (
<IAIDraggable
data={draggableData}
disabled={isDragDisabled || !imageDTO}
onClick={onClick}
/>
)}
{onClickReset && withResetIcon && imageDTO && (
<IAIIconButton
onClick={onClickReset}
aria-label={resetTooltip}
tooltip={resetTooltip}
icon={resetIcon}
size="sm"
variant="link"
sx={{
position: 'absolute',
top: 1,
insetInlineEnd: 1,
p: 0,
minW: 0,
svg: {
transitionProperty: 'common',
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: resetIconShadow,
},
}}
/>
)}
</Flex>
); );
}; };

View File

@ -8,19 +8,34 @@ import {
import { useAppDispatch } from 'app/store/storeHooks';
import { stopPastePropagation } from 'common/util/stopPastePropagation';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { ChangeEvent, KeyboardEvent, memo, useCallback } from 'react';
import {
CSSProperties,
ChangeEvent,
KeyboardEvent,
memo,
useCallback,
} from 'react';
interface IAIInputProps extends InputProps {
label?: string;
labelPos?: 'top' | 'side';
value?: string;
size?: string;
onChange?: (e: ChangeEvent<HTMLInputElement>) => void;
formControlProps?: Omit<FormControlProps, 'isInvalid' | 'isDisabled'>;
}
const labelPosVerticalStyle: CSSProperties = {
display: 'flex',
flexDirection: 'row',
alignItems: 'center',
gap: 10,
};
const IAIInput = (props: IAIInputProps) => {
const {
label = '',
labelPos = 'top',
isDisabled = false,
isInvalid,
formControlProps,
@ -51,6 +66,7 @@ const IAIInput = (props: IAIInputProps) => {
isInvalid={isInvalid}
isDisabled={isDisabled}
{...formControlProps}
style={labelPos === 'side' ? labelPosVerticalStyle : undefined}
>
{label !== '' && <FormLabel>{label}</FormLabel>}
<Input

View File

@ -36,6 +36,7 @@ export default function IAIMantineTextInput(props: IAIMantineTextInputProps) {
label: {
color: mode(base700, base300)(colorMode),
fontWeight: 'normal',
marginBottom: 4,
},
})}
{...rest}

View File

@ -9,14 +9,14 @@ export type IAISelectDataType = {
tooltip?: string;
};
type IAISelectProps = Omit<SelectProps, 'label'> & {
export type IAISelectProps = Omit<SelectProps, 'label'> & {
tooltip?: string;
inputRef?: RefObject<HTMLInputElement>;
label?: string;
};
const IAIMantineSelect = (props: IAISelectProps) => {
const { tooltip, inputRef, label, disabled, ...rest } = props;
const { tooltip, inputRef, label, disabled, required, ...rest } = props;
const styles = useMantineSelectStyles();
@ -25,7 +25,7 @@ const IAIMantineSelect = (props: IAISelectProps) => {
<Select
label={
label ? (
<FormControl isDisabled={disabled}>
<FormControl isRequired={required} isDisabled={disabled}>
<FormLabel>{label}</FormLabel>
</FormControl>
) : undefined

View File

@ -3,4 +3,5 @@ import dateFormat from 'dateformat';
/**
 * Get a `now` timestamp with 1s precision, formatted as ISO datetime.
 */
export const getTimestamp = () => dateFormat(new Date(), 'isoDateTime');
export const getTimestamp = () =>
dateFormat(new Date(), `yyyy-mm-dd'T'HH:MM:ss:lo`);

View File

@ -48,6 +48,7 @@ import IAICanvasRedoButton from './IAICanvasRedoButton';
import IAICanvasSettingsButtonPopover from './IAICanvasSettingsButtonPopover';
import IAICanvasToolChooserOptions from './IAICanvasToolChooserOptions';
import IAICanvasUndoButton from './IAICanvasUndoButton';
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
export const selector = createSelector(
[systemSelector, canvasSelector, isStagingSelector],
@ -79,6 +80,7 @@ const IAICanvasToolbar = () => {
const canvasBaseLayer = getCanvasBaseLayer();
const { t } = useTranslation();
const { isClipboardAPIAvailable } = useCopyImageToClipboard();
const { openUploader } = useImageUploader();
@ -136,10 +138,10 @@ const IAICanvasToolbar = () => {
handleCopyImageToClipboard();
},
{
enabled: () => !isStaging,
enabled: () => !isStaging && isClipboardAPIAvailable,
preventDefault: true,
},
[canvasBaseLayer, isProcessing]
[canvasBaseLayer, isProcessing, isClipboardAPIAvailable]
);
useHotkeys(
@ -189,6 +191,9 @@ const IAICanvasToolbar = () => {
};
const handleCopyImageToClipboard = () => {
if (!isClipboardAPIAvailable) {
return;
}
dispatch(canvasCopiedToClipboard());
};
@ -256,13 +261,15 @@ const IAICanvasToolbar = () => {
onClick={handleSaveToGallery}
isDisabled={isStaging}
/>
<IAIIconButton
aria-label={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`}
tooltip={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`}
icon={<FaCopy />}
onClick={handleCopyImageToClipboard}
isDisabled={isStaging}
/>
{isClipboardAPIAvailable && (
<IAIIconButton
aria-label={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`}
tooltip={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`}
icon={<FaCopy />}
onClick={handleCopyImageToClipboard}
isDisabled={isStaging}
/>
)}
<IAIIconButton
aria-label={`${t('unifiedCanvas.downloadAsImage')} (Shift+D)`}
tooltip={`${t('unifiedCanvas.downloadAsImage')} (Shift+D)`}

View File

@ -1,7 +1,16 @@
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash-es';
import { ButtonGroup, Flex, FlexProps, Link } from '@chakra-ui/react';
import {
ButtonGroup,
Flex,
FlexProps,
Link,
Menu,
MenuButton,
MenuItem,
MenuList,
} from '@chakra-ui/react';
// import { runESRGAN, runFacetool } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
@ -20,6 +29,7 @@ import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/U
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import {
setActiveTab,
@ -48,6 +58,8 @@ import {
} from 'services/api/endpoints/images';
import { useDebounce } from 'use-debounce';
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
import { menuListMotionProps } from 'theme/components/menu';
import SingleSelectionMenuItems from '../ImageContextMenu/SingleSelectionMenuItems';
const currentImageButtonsSelector = createSelector(
[stateSelector, activeTabNameSelector],
@ -120,6 +132,9 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const toaster = useAppToaster();
const { t } = useTranslation();
const { isClipboardAPIAvailable, copyImageToClipboard } =
useCopyImageToClipboard();
const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters();
@ -128,7 +143,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
500
);
const { currentData: image, isFetching } = useGetImageDTOQuery(
const { currentData: imageDTO, isFetching } = useGetImageDTOQuery(
lastSelectedImage ?? skipToken
);
@ -142,15 +157,15 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const handleCopyImageLink = useCallback(() => {
const getImageUrl = () => {
if (!image) {
if (!imageDTO) {
return;
}
if (image.image_url.startsWith('http')) {
if (imageDTO.image_url.startsWith('http')) {
return image.image_url;
return imageDTO.image_url;
}
return window.location.toString() + image.image_url;
return window.location.toString() + imageDTO.image_url;
};
const url = getImageUrl();
@ -174,7 +189,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
isClosable: true,
});
});
}, [toaster, t, image]);
}, [toaster, t, imageDTO]);
const handleClickUseAllParameters = useCallback(() => {
recallAllParameters(metadata);
@ -192,31 +207,31 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
recallSeed(metadata?.seed);
}, [metadata?.seed, recallSeed]);
useHotkeys('s', handleUseSeed, [image]);
useHotkeys('s', handleUseSeed, [imageDTO]);
const handleUsePrompt = useCallback(() => {
recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt);
}, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]);
useHotkeys('p', handleUsePrompt, [image]);
useHotkeys('p', handleUsePrompt, [imageDTO]);
const handleSendToImageToImage = useCallback(() => {
dispatch(sentImageToImg2Img());
dispatch(initialImageSelected(image));
dispatch(initialImageSelected(imageDTO));
}, [dispatch, image]);
}, [dispatch, imageDTO]);
useHotkeys('shift+i', handleSendToImageToImage, [image]);
useHotkeys('shift+i', handleSendToImageToImage, [imageDTO]);
const handleClickUpscale = useCallback(() => {
// selectedImage && dispatch(runESRGAN(selectedImage));
}, []);
const handleDelete = useCallback(() => {
if (!image) {
if (!imageDTO) {
return;
}
dispatch(imageToDeleteSelected(image));
dispatch(imageToDeleteSelected(imageDTO));
}, [dispatch, image]);
}, [dispatch, imageDTO]);
useHotkeys(
'Shift+U',
@ -236,7 +251,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
},
[
isUpscalingEnabled,
image,
imageDTO,
isESRGANAvailable,
shouldDisableToolbarButtons,
isConnected,
@ -268,7 +283,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
[
isFaceRestoreEnabled,
image,
imageDTO,
isGFPGANAvailable,
shouldDisableToolbarButtons,
isConnected,
@ -283,10 +298,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
);
const handleSendToCanvas = useCallback(() => {
if (!image) return;
if (!imageDTO) return;
dispatch(sentImageToCanvas());
dispatch(setInitialCanvasImage(image));
dispatch(setInitialCanvasImage(imageDTO));
dispatch(requestCanvasRescale());
if (activeTabName !== 'unifiedCanvas') {
@ -299,12 +314,12 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
duration: 2500,
isClosable: true,
});
}, [image, dispatch, activeTabName, toaster, t]);
}, [imageDTO, dispatch, activeTabName, toaster, t]);
useHotkeys( useHotkeys(
'i', 'i',
() => { () => {
if (image) { if (imageDTO) {
handleClickShowImageDetails(); handleClickShowImageDetails();
} else { } else {
toaster({ toaster({
@ -315,13 +330,20 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}); });
} }
}, },
[image, shouldShowImageDetails, toaster] [imageDTO, shouldShowImageDetails, toaster]
); );
const handleClickProgressImagesToggle = useCallback(() => { const handleClickProgressImagesToggle = useCallback(() => {
dispatch(setShouldShowProgressInViewer(!shouldShowProgressInViewer)); dispatch(setShouldShowProgressInViewer(!shouldShowProgressInViewer));
}, [dispatch, shouldShowProgressInViewer]); }, [dispatch, shouldShowProgressInViewer]);
const handleCopyImage = useCallback(() => {
if (!imageDTO) {
return;
}
copyImageToClipboard(imageDTO.image_url);
}, [copyImageToClipboard, imageDTO]);
return ( return (
<> <>
<Flex <Flex
@ -334,63 +356,18 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
{...props} {...props}
> >
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}> <ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
<IAIPopover <Menu>
triggerComponent={ <MenuButton
<IAIIconButton as={IAIIconButton}
aria-label={`${t('parameters.sendTo')}...`} aria-label={`${t('parameters.sendTo')}...`}
tooltip={`${t('parameters.sendTo')}...`} tooltip={`${t('parameters.sendTo')}...`}
isDisabled={!image} isDisabled={!imageDTO}
icon={<FaShareAlt />} icon={<FaShareAlt />}
/> />
} <MenuList motionProps={menuListMotionProps}>
> {imageDTO && <SingleSelectionMenuItems imageDTO={imageDTO} />}
<Flex </MenuList>
sx={{ </Menu>
flexDirection: 'column',
rowGap: 2,
}}
>
<IAIButton
size="sm"
onClick={handleSendToImageToImage}
leftIcon={<FaShare />}
id="send-to-img2img"
>
{t('parameters.sendToImg2Img')}
</IAIButton>
{isCanvasEnabled && (
<IAIButton
size="sm"
onClick={handleSendToCanvas}
leftIcon={<FaShare />}
id="send-to-canvas"
>
{t('parameters.sendToUnifiedCanvas')}
</IAIButton>
)}
{/* <IAIButton
size="sm"
onClick={handleCopyImage}
leftIcon={<FaCopy />}
>
{t('parameters.copyImage')}
</IAIButton> */}
<IAIButton
size="sm"
onClick={handleCopyImageLink}
leftIcon={<FaCopy />}
>
{t('parameters.copyImageToLink')}
</IAIButton>
<Link download={true} href={image?.image_url} target="_blank">
<IAIButton leftIcon={<FaDownload />} size="sm" w="100%">
{t('parameters.downloadImage')}
</IAIButton>
</Link>
</Flex>
</IAIPopover>
</ButtonGroup> </ButtonGroup>
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}> <ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
@ -443,7 +420,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<IAIButton <IAIButton
isDisabled={ isDisabled={
!isGFPGANAvailable || !isGFPGANAvailable ||
!image || !imageDTO ||
!(isConnected && !isProcessing) || !(isConnected && !isProcessing) ||
!facetoolStrength !facetoolStrength
} }
@ -474,7 +451,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<IAIButton <IAIButton
isDisabled={ isDisabled={
!isESRGANAvailable || !isESRGANAvailable ||
!image || !imageDTO ||
!(isConnected && !isProcessing) || !(isConnected && !isProcessing) ||
!upscalingLevel !upscalingLevel
} }

View File

@ -4,13 +4,14 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu'; import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu';
import { memo, useMemo } from 'react'; import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
import { menuListMotionProps } from 'theme/components/menu';
import MultipleSelectionMenuItems from './MultipleSelectionMenuItems'; import MultipleSelectionMenuItems from './MultipleSelectionMenuItems';
import SingleSelectionMenuItems from './SingleSelectionMenuItems'; import SingleSelectionMenuItems from './SingleSelectionMenuItems';
type Props = { type Props = {
imageDTO: ImageDTO; imageDTO: ImageDTO | undefined;
children: ContextMenuProps<HTMLDivElement>['children']; children: ContextMenuProps<HTMLDivElement>['children'];
}; };
@ -31,18 +32,32 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => {
const { selectionCount } = useAppSelector(selector); const { selectionCount } = useAppSelector(selector);
const handleContextMenu = useCallback((e: MouseEvent<HTMLDivElement>) => {
e.preventDefault();
}, []);
return ( return (
<ContextMenu<HTMLDivElement> <ContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }} menuProps={{ size: 'sm', isLazy: true }}
renderMenu={() => ( menuButtonProps={{
<MenuList sx={{ visibility: 'visible !important' }}> bg: 'transparent',
{selectionCount === 1 ? ( _hover: { bg: 'transparent' },
<SingleSelectionMenuItems imageDTO={imageDTO} /> }}
) : ( renderMenu={() =>
<MultipleSelectionMenuItems /> imageDTO ? (
)} <MenuList
</MenuList> sx={{ visibility: 'visible !important' }}
)} motionProps={menuListMotionProps}
onContextMenu={handleContextMenu}
>
{selectionCount === 1 ? (
<SingleSelectionMenuItems imageDTO={imageDTO} />
) : (
<MultipleSelectionMenuItems />
)}
</MenuList>
) : null
}
> >
{children} {children}
</ContextMenu> </ContextMenu>

View File

@ -1,5 +1,4 @@
import { ExternalLinkIcon } from '@chakra-ui/icons'; import { Link, MenuItem } from '@chakra-ui/react';
import { MenuItem } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
@ -14,11 +13,21 @@ import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletio
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions'; import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
import { setActiveTab } from 'features/ui/store/uiSlice'; import { setActiveTab } from 'features/ui/store/uiSlice';
import { memo, useCallback, useContext, useMemo } from 'react'; import { memo, useCallback, useContext, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { FaFolder, FaShare, FaTrash } from 'react-icons/fa'; import {
import { IoArrowUndoCircleOutline } from 'react-icons/io5'; FaAsterisk,
FaCopy,
FaDownload,
FaExternalLinkAlt,
FaFolder,
FaQuoteRight,
FaSeedling,
FaShare,
FaTrash,
} from 'react-icons/fa';
import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages'; import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages';
import { useGetImageMetadataQuery } from 'services/api/endpoints/images'; import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
@ -61,6 +70,9 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const { currentData } = useGetImageMetadataQuery(imageDTO.image_name); const { currentData } = useGetImageMetadataQuery(imageDTO.image_name);
const { isClipboardAPIAvailable, copyImageToClipboard } =
useCopyImageToClipboard();
const metadata = currentData?.metadata; const metadata = currentData?.metadata;
const handleDelete = useCallback(() => { const handleDelete = useCallback(() => {
@ -130,13 +142,27 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
dispatch(imagesAddedToBatch([imageDTO.image_name])); dispatch(imagesAddedToBatch([imageDTO.image_name]));
}, [dispatch, imageDTO.image_name]); }, [dispatch, imageDTO.image_name]);
const handleCopyImage = useCallback(() => {
copyImageToClipboard(imageDTO.image_url);
}, [copyImageToClipboard, imageDTO.image_url]);
return ( return (
<> <>
<MenuItem icon={<ExternalLinkIcon />} onClickCapture={handleOpenInNewTab}> <Link href={imageDTO.image_url} target="_blank">
{t('common.openInNewTab')} <MenuItem
</MenuItem> icon={<FaExternalLinkAlt />}
onClickCapture={handleOpenInNewTab}
>
{t('common.openInNewTab')}
</MenuItem>
</Link>
{isClipboardAPIAvailable && (
<MenuItem icon={<FaCopy />} onClickCapture={handleCopyImage}>
{t('parameters.copyImage')}
</MenuItem>
)}
<MenuItem <MenuItem
icon={<IoArrowUndoCircleOutline />} icon={<FaQuoteRight />}
onClickCapture={handleRecallPrompt} onClickCapture={handleRecallPrompt}
isDisabled={ isDisabled={
metadata?.positive_prompt === undefined && metadata?.positive_prompt === undefined &&
@ -147,14 +173,14 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
</MenuItem> </MenuItem>
<MenuItem <MenuItem
icon={<IoArrowUndoCircleOutline />} icon={<FaSeedling />}
onClickCapture={handleRecallSeed} onClickCapture={handleRecallSeed}
isDisabled={metadata?.seed === undefined} isDisabled={metadata?.seed === undefined}
> >
{t('parameters.useSeed')} {t('parameters.useSeed')}
</MenuItem> </MenuItem>
<MenuItem <MenuItem
icon={<IoArrowUndoCircleOutline />} icon={<FaAsterisk />}
onClickCapture={handleUseAllParameters} onClickCapture={handleUseAllParameters}
isDisabled={!metadata} isDisabled={!metadata}
> >
@ -193,6 +219,11 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
Remove from Board Remove from Board
</MenuItem> </MenuItem>
)} )}
<Link download={true} href={imageDTO.image_url} target="_blank">
<MenuItem icon={<FaDownload />} w="100%">
{t('parameters.downloadImage')}
</MenuItem>
</Link>
<MenuItem <MenuItem
sx={{ color: 'error.600', _dark: { color: 'error.300' } }} sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
icon={<FaTrash />} icon={<FaTrash />}
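
Both `CurrentImageButtons` and `SingleSelectionMenuItems` now depend on a `useCopyImageToClipboard` hook that is imported from `features/ui/hooks/useCopyImageToClipboard` but not shown in this diff. A minimal sketch of what such a hook could look like, assuming only the return shape used above (`isClipboardAPIAvailable`, `copyImageToClipboard(image_url)`) and the browser Clipboard API; the real hook may also raise toasts on success or failure:

```ts
import { useCallback, useMemo } from 'react';

export const useCopyImageToClipboard = () => {
  // Feature-detect the async Clipboard API; ClipboardItem is missing in some browsers
  const isClipboardAPIAvailable = useMemo(
    () => Boolean(navigator.clipboard) && Boolean(window.ClipboardItem),
    []
  );

  const copyImageToClipboard = useCallback(
    async (image_url: string) => {
      if (!isClipboardAPIAvailable) {
        return;
      }
      // Fetch the image and write it to the clipboard as a blob
      const blob = await fetch(image_url).then((res) => res.blob());
      await navigator.clipboard.write([
        new ClipboardItem({ [blob.type]: blob }),
      ]);
    },
    [isClipboardAPIAvailable]
  );

  return { isClipboardAPIAvailable, copyImageToClipboard };
};
```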

View File

@ -16,14 +16,13 @@ import {
ASSETS_CATEGORIES, ASSETS_CATEGORIES,
IMAGE_CATEGORIES, IMAGE_CATEGORIES,
IMAGE_LIMIT, IMAGE_LIMIT,
selectImagesAll,
} from 'features/gallery//store/gallerySlice'; } from 'features/gallery//store/gallerySlice';
import { selectFilteredImages } from 'features/gallery/store/gallerySelectors'; import { selectFilteredImages } from 'features/gallery/store/gallerySelectors';
import { VirtuosoGrid } from 'react-virtuoso'; import { VirtuosoGrid } from 'react-virtuoso';
import { receivedPageOfImages } from 'services/api/thunks/image'; import { receivedPageOfImages } from 'services/api/thunks/image';
import { useListBoardImagesQuery } from '../../../../services/api/endpoints/boardImages';
import ImageGridItemContainer from './ImageGridItemContainer'; import ImageGridItemContainer from './ImageGridItemContainer';
import ImageGridListContainer from './ImageGridListContainer'; import ImageGridListContainer from './ImageGridListContainer';
import { useListBoardImagesQuery } from '../../../../services/api/endpoints/boardImages';
const selector = createSelector( const selector = createSelector(
[stateSelector, selectFilteredImages], [stateSelector, selectFilteredImages],
@ -180,7 +179,6 @@ const GalleryImageGrid = () => {
</Box> </Box>
); );
} }
console.log({ selectedBoardId });
if (status !== 'rejected') { if (status !== 'rejected') {
return ( return (

View File

@ -2,9 +2,12 @@ import { ButtonGroup } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { FaCode, FaExpand, FaMinus, FaPlus } from 'react-icons/fa'; import { FaCode, FaExpand, FaMinus, FaPlus, FaInfo } from 'react-icons/fa';
import { useReactFlow } from 'reactflow'; import { useReactFlow } from 'reactflow';
import { shouldShowGraphOverlayChanged } from '../store/nodesSlice'; import {
shouldShowGraphOverlayChanged,
shouldShowFieldTypeLegendChanged,
} from '../store/nodesSlice';
const ViewportControls = () => { const ViewportControls = () => {
const { zoomIn, zoomOut, fitView } = useReactFlow(); const { zoomIn, zoomOut, fitView } = useReactFlow();
@ -12,6 +15,9 @@ const ViewportControls = () => {
const shouldShowGraphOverlay = useAppSelector( const shouldShowGraphOverlay = useAppSelector(
(state) => state.nodes.shouldShowGraphOverlay (state) => state.nodes.shouldShowGraphOverlay
); );
const shouldShowFieldTypeLegend = useAppSelector(
(state) => state.nodes.shouldShowFieldTypeLegend
);
const handleClickedZoomIn = useCallback(() => { const handleClickedZoomIn = useCallback(() => {
zoomIn(); zoomIn();
@ -29,6 +35,10 @@ const ViewportControls = () => {
dispatch(shouldShowGraphOverlayChanged(!shouldShowGraphOverlay)); dispatch(shouldShowGraphOverlayChanged(!shouldShowGraphOverlay));
}, [shouldShowGraphOverlay, dispatch]); }, [shouldShowGraphOverlay, dispatch]);
const handleClickedToggleFieldTypeLegend = useCallback(() => {
dispatch(shouldShowFieldTypeLegendChanged(!shouldShowFieldTypeLegend));
}, [shouldShowFieldTypeLegend, dispatch]);
return ( return (
<ButtonGroup isAttached orientation="vertical"> <ButtonGroup isAttached orientation="vertical">
<IAIIconButton <IAIIconButton
@ -52,6 +62,12 @@ const ViewportControls = () => {
aria-label="Show/Hide Graph" aria-label="Show/Hide Graph"
icon={<FaCode />} icon={<FaCode />}
/> />
<IAIIconButton
isChecked={shouldShowFieldTypeLegend}
onClick={handleClickedToggleFieldTypeLegend}
aria-label="Show/Hide Field Type Legend"
icon={<FaInfo />}
/>
</ButtonGroup> </ButtonGroup>
); );
}; };

View File

@ -9,10 +9,13 @@ const TopRightPanel = () => {
const shouldShowGraphOverlay = useAppSelector( const shouldShowGraphOverlay = useAppSelector(
(state: RootState) => state.nodes.shouldShowGraphOverlay (state: RootState) => state.nodes.shouldShowGraphOverlay
); );
const shouldShowFieldTypeLegend = useAppSelector(
(state: RootState) => state.nodes.shouldShowFieldTypeLegend
);
return ( return (
<Panel position="top-right"> <Panel position="top-right">
<FieldTypeLegend /> {shouldShowFieldTypeLegend && <FieldTypeLegend />}
{shouldShowGraphOverlay && <NodeGraphOverlay />} {shouldShowGraphOverlay && <NodeGraphOverlay />}
</Panel> </Panel>
); );

View File

@ -32,6 +32,7 @@ export type NodesState = {
invocationTemplates: Record<string, InvocationTemplate>; invocationTemplates: Record<string, InvocationTemplate>;
connectionStartParams: OnConnectStartParams | null; connectionStartParams: OnConnectStartParams | null;
shouldShowGraphOverlay: boolean; shouldShowGraphOverlay: boolean;
shouldShowFieldTypeLegend: boolean;
editorInstance: ReactFlowInstance | undefined; editorInstance: ReactFlowInstance | undefined;
}; };
@ -42,6 +43,7 @@ export const initialNodesState: NodesState = {
invocationTemplates: {}, invocationTemplates: {},
connectionStartParams: null, connectionStartParams: null,
shouldShowGraphOverlay: false, shouldShowGraphOverlay: false,
shouldShowFieldTypeLegend: false,
editorInstance: undefined, editorInstance: undefined,
}; };
@ -125,6 +127,12 @@ const nodesSlice = createSlice({
shouldShowGraphOverlayChanged: (state, action: PayloadAction<boolean>) => { shouldShowGraphOverlayChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowGraphOverlay = action.payload; state.shouldShowGraphOverlay = action.payload;
}, },
shouldShowFieldTypeLegendChanged: (
state,
action: PayloadAction<boolean>
) => {
state.shouldShowFieldTypeLegend = action.payload;
},
nodeTemplatesBuilt: ( nodeTemplatesBuilt: (
state, state,
action: PayloadAction<Record<string, InvocationTemplate>> action: PayloadAction<Record<string, InvocationTemplate>>
@ -161,6 +169,7 @@ export const {
connectionStarted, connectionStarted,
connectionEnded, connectionEnded,
shouldShowGraphOverlayChanged, shouldShowGraphOverlayChanged,
shouldShowFieldTypeLegendChanged,
nodeTemplatesBuilt, nodeTemplatesBuilt,
nodeEditorReset, nodeEditorReset,
imageCollectionFieldValueChanged, imageCollectionFieldValueChanged,
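
A small sanity check for the new legend toggle — this assumes the slice's reducer is its default export, which is the usual Redux Toolkit pattern but is not visible in this diff:

```ts
import nodesReducer, {
  initialNodesState,
  shouldShowFieldTypeLegendChanged,
} from 'features/nodes/store/nodesSlice';

// Toggling the flag should only touch shouldShowFieldTypeLegend
const next = nodesReducer(
  initialNodesState,
  shouldShowFieldTypeLegendChanged(true)
);
// next.shouldShowFieldTypeLegend === true; all other fields keep their initial values
```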

View File

@ -1,11 +1,10 @@
import { UseToastOptions } from '@chakra-ui/react'; import { UseToastOptions } from '@chakra-ui/react';
import { PayloadAction, createSlice } from '@reduxjs/toolkit'; import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/types/invokeai';
import { InvokeLogLevel } from 'app/logging/useLogger'; import { InvokeLogLevel } from 'app/logging/useLogger';
import { userInvoked } from 'app/store/actions'; import { userInvoked } from 'app/store/actions';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice'; import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { TFuncKey, t } from 'i18next'; import { t } from 'i18next';
import { LogLevelName } from 'roarr'; import { LogLevelName } from 'roarr';
import { imageUploaded } from 'services/api/thunks/image'; import { imageUploaded } from 'services/api/thunks/image';
import { import {
@ -44,8 +43,6 @@ export interface SystemState {
isCancelable: boolean; isCancelable: boolean;
enableImageDebugging: boolean; enableImageDebugging: boolean;
toastQueue: UseToastOptions[]; toastQueue: UseToastOptions[];
searchFolder: string | null;
foundModels: InvokeAI.FoundModel[] | null;
/** /**
* The current progress image * The current progress image
*/ */
@ -79,7 +76,7 @@ export interface SystemState {
*/ */
consoleLogLevel: InvokeLogLevel; consoleLogLevel: InvokeLogLevel;
shouldLogToConsole: boolean; shouldLogToConsole: boolean;
statusTranslationKey: TFuncKey; statusTranslationKey: any;
/** /**
* When a session is canceled, its ID is stored here until a new session is created. * When a session is canceled, its ID is stored here until a new session is created.
*/ */
@ -106,8 +103,6 @@ export const initialSystemState: SystemState = {
isCancelable: true, isCancelable: true,
enableImageDebugging: false, enableImageDebugging: false,
toastQueue: [], toastQueue: [],
searchFolder: null,
foundModels: null,
progressImage: null, progressImage: null,
shouldAntialiasProgressImage: false, shouldAntialiasProgressImage: false,
sessionId: null, sessionId: null,
@ -132,7 +127,7 @@ export const systemSlice = createSlice({
setIsProcessing: (state, action: PayloadAction<boolean>) => { setIsProcessing: (state, action: PayloadAction<boolean>) => {
state.isProcessing = action.payload; state.isProcessing = action.payload;
}, },
setCurrentStatus: (state, action: PayloadAction<TFuncKey>) => { setCurrentStatus: (state, action: any) => {
state.statusTranslationKey = action.payload; state.statusTranslationKey = action.payload;
}, },
setShouldConfirmOnDelete: (state, action: PayloadAction<boolean>) => { setShouldConfirmOnDelete: (state, action: PayloadAction<boolean>) => {
@ -153,15 +148,6 @@ export const systemSlice = createSlice({
clearToastQueue: (state) => { clearToastQueue: (state) => {
state.toastQueue = []; state.toastQueue = [];
}, },
setSearchFolder: (state, action: PayloadAction<string | null>) => {
state.searchFolder = action.payload;
},
setFoundModels: (
state,
action: PayloadAction<InvokeAI.FoundModel[] | null>
) => {
state.foundModels = action.payload;
},
/** /**
* A cancel was scheduled * A cancel was scheduled
*/ */
@ -426,8 +412,6 @@ export const {
setEnableImageDebugging, setEnableImageDebugging,
addToast, addToast,
clearToastQueue, clearToastQueue,
setSearchFolder,
setFoundModels,
cancelScheduled, cancelScheduled,
scheduledCancelAborted, scheduledCancelAborted,
cancelTypeChanged, cancelTypeChanged,

View File

@ -1,11 +1,11 @@
import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@chakra-ui/react'; import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@chakra-ui/react';
import i18n from 'i18n'; import i18n from 'i18n';
import { ReactNode, memo } from 'react'; import { ReactNode, memo } from 'react';
import AddModelsPanel from './subpanels/AddModelsPanel'; import ImportModelsPanel from './subpanels/ImportModelsPanel';
import MergeModelsPanel from './subpanels/MergeModelsPanel'; import MergeModelsPanel from './subpanels/MergeModelsPanel';
import ModelManagerPanel from './subpanels/ModelManagerPanel'; import ModelManagerPanel from './subpanels/ModelManagerPanel';
type ModelManagerTabName = 'modelManager' | 'addModels' | 'mergeModels'; type ModelManagerTabName = 'modelManager' | 'importModels' | 'mergeModels';
type ModelManagerTabInfo = { type ModelManagerTabInfo = {
id: ModelManagerTabName; id: ModelManagerTabName;
@ -20,9 +20,9 @@ const tabs: ModelManagerTabInfo[] = [
content: <ModelManagerPanel />, content: <ModelManagerPanel />,
}, },
{ {
id: 'addModels', id: 'importModels',
label: i18n.t('modelManager.addModel'), label: i18n.t('modelManager.importModels'),
content: <AddModelsPanel />, content: <ImportModelsPanel />,
}, },
{ {
id: 'mergeModels', id: 'mergeModels',
@ -46,7 +46,7 @@ const ModelManagerTab = () => {
</Tab> </Tab>
))} ))}
</TabList> </TabList>
<TabPanels sx={{ w: 'full', h: 'full', p: 4 }}> <TabPanels sx={{ w: 'full', h: 'full' }}>
{tabs.map((tab) => ( {tabs.map((tab) => (
<TabPanel sx={{ w: 'full', h: 'full' }} key={tab.id}> <TabPanel sx={{ w: 'full', h: 'full' }} key={tab.id}>
{tab.content} {tab.content}

View File

@ -0,0 +1,29 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
type ModelManagerState = {
searchFolder: string | null;
advancedAddScanModel: string | null;
};
const initialModelManagerState: ModelManagerState = {
searchFolder: null,
advancedAddScanModel: null,
};
export const modelManagerSlice = createSlice({
name: 'modelmanager',
initialState: initialModelManagerState,
reducers: {
setSearchFolder: (state, action: PayloadAction<string | null>) => {
state.searchFolder = action.payload;
},
setAdvancedAddScanModel: (state, action: PayloadAction<string | null>) => {
state.advancedAddScanModel = action.payload;
},
},
});
export const { setSearchFolder, setAdvancedAddScanModel } =
modelManagerSlice.actions;
export default modelManagerSlice.reducer;
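
The new `modelmanager` slice only takes effect once its reducer is registered under the `modelmanager` key that the selectors below (`state.modelmanager`) expect. A sketch of that wiring; the store file layout and the other slice names are assumptions, not part of this diff:

```ts
// Hypothetical excerpt of the root store setup (real import path differs)
import { configureStore } from '@reduxjs/toolkit';
import modelmanagerReducer from './modelManagerSlice';

export const store = configureStore({
  reducer: {
    // ...the application's existing slices (gallery, nodes, system, ui, ...)
    modelmanager: modelmanagerReducer, // key must match `state.modelmanager` used in the selectors
  },
});

export type RootState = ReturnType<typeof store.getState>;
```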

View File

@ -0,0 +1,3 @@
import { RootState } from 'app/store/store';
export const modelmanagerSelector = (state: RootState) => state.modelmanager;

View File

@ -1,43 +0,0 @@
import { Divider, Flex, useColorMode } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import { useTranslation } from 'react-i18next';
import AddCheckpointModel from './AddModelsPanel/AddCheckpointModel';
import AddDiffusersModel from './AddModelsPanel/AddDiffusersModel';
export default function AddModelsPanel() {
const addNewModelUIOption = useAppSelector(
(state: RootState) => state.ui.addNewModelUIOption
);
const { colorMode } = useColorMode();
const dispatch = useAppDispatch();
const { t } = useTranslation();
return (
<Flex flexDirection="column" gap={4}>
<Flex columnGap={4}>
<IAIButton
onClick={() => dispatch(setAddNewModelUIOption('ckpt'))}
isChecked={addNewModelUIOption == 'ckpt'}
>
{t('modelManager.addCheckpointModel')}
</IAIButton>
<IAIButton
onClick={() => dispatch(setAddNewModelUIOption('diffusers'))}
isChecked={addNewModelUIOption == 'diffusers'}
>
{t('modelManager.addDiffuserModel')}
</IAIButton>
</Flex>
<Divider />
{addNewModelUIOption == 'ckpt' && <AddCheckpointModel />}
{addNewModelUIOption == 'diffusers' && <AddDiffusersModel />}
</Flex>
);
}

View File

@ -1,337 +0,0 @@
import {
Flex,
FormControl,
FormErrorMessage,
FormHelperText,
FormLabel,
HStack,
Text,
VStack,
} from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import React from 'react';
// import { addNewModel } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next';
import type { RootState } from 'app/store/store';
import type { InvokeModelConfigProps } from 'app/types/invokeai';
import IAIForm from 'common/components/IAIForm';
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import type { FieldInputProps, FormikProps } from 'formik';
import SearchModels from './SearchModels';
const MIN_MODEL_SIZE = 64;
const MAX_MODEL_SIZE = 2048;
export default function AddCheckpointModel() {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
function hasWhiteSpace(s: string) {
return /\s/.test(s);
}
function baseValidation(value: string) {
let error;
if (hasWhiteSpace(value)) error = t('modelManager.cannotUseSpaces');
return error;
}
const addModelFormValues: InvokeModelConfigProps = {
name: '',
description: '',
config: 'configs/stable-diffusion/v1-inference.yaml',
weights: '',
vae: '',
width: 512,
height: 512,
format: 'ckpt',
default: false,
};
const addModelFormSubmitHandler = (values: InvokeModelConfigProps) => {
dispatch(addNewModel(values));
dispatch(setAddNewModelUIOption(null));
};
const [addManually, setAddmanually] = React.useState<boolean>(false);
return (
<VStack gap={2} alignItems="flex-start">
<Flex columnGap={4}>
<IAISimpleCheckbox
isChecked={!addManually}
label={t('modelManager.scanForModels')}
onChange={() => setAddmanually(!addManually)}
/>
<IAISimpleCheckbox
label={t('modelManager.addManually')}
isChecked={addManually}
onChange={() => setAddmanually(!addManually)}
/>
</Flex>
{addManually ? (
<Formik
initialValues={addModelFormValues}
onSubmit={addModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<IAIForm onSubmit={handleSubmit} sx={{ w: 'full' }}>
<VStack rowGap={2}>
<Text fontSize={20} fontWeight="bold" alignSelf="start">
{t('modelManager.manual')}
</Text>
{/* Name */}
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.name && touched.name}
isRequired
>
<FormLabel htmlFor="name" fontSize="sm">
{t('modelManager.name')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="name"
name="name"
type="text"
validate={baseValidation}
width="full"
/>
{!!errors.name && touched.name ? (
<FormErrorMessage>{errors.name}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.nameValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Description */}
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.description && touched.description}
isRequired
>
<FormLabel htmlFor="description" fontSize="sm">
{t('modelManager.description')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="description"
name="description"
type="text"
width="full"
/>
{!!errors.description && touched.description ? (
<FormErrorMessage>
{errors.description}
</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.descriptionValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Config */}
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.config && touched.config}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.config')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="config"
name="config"
type="text"
width="full"
/>
{!!errors.config && touched.config ? (
<FormErrorMessage>{errors.config}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.configValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Weights */}
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.weights && touched.weights}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="weights"
name="weights"
type="text"
width="full"
/>
{!!errors.weights && touched.weights ? (
<FormErrorMessage>{errors.weights}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.modelLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* VAE */}
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.vae && touched.vae}>
<FormLabel htmlFor="vae" fontSize="sm">
{t('modelManager.vaeLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae"
name="vae"
type="text"
width="full"
/>
{!!errors.vae && touched.vae ? (
<FormErrorMessage>{errors.vae}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.vaeLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<HStack width="100%">
{/* Width */}
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.width && touched.width}>
<FormLabel htmlFor="width" fontSize="sm">
{t('modelManager.width')}
</FormLabel>
<VStack alignItems="start">
<Field id="width" name="width">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="width"
name="width"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
value={form.values.width}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
{!!errors.width && touched.width ? (
<FormErrorMessage>{errors.width}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.widthValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Height */}
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.height && touched.height}>
<FormLabel htmlFor="height" fontSize="sm">
{t('modelManager.height')}
</FormLabel>
<VStack alignItems="start">
<Field id="height" name="height">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="height"
name="height"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
value={form.values.height}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
{!!errors.height && touched.height ? (
<FormErrorMessage>{errors.height}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.heightValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
</HStack>
<IAIButton
type="submit"
className="modal-close-btn"
isLoading={isProcessing}
>
{t('modelManager.addModel')}
</IAIButton>
</VStack>
</IAIForm>
)}
</Formik>
) : (
<SearchModels />
)}
</VStack>
);
}

View File

@ -1,259 +0,0 @@
import {
Flex,
FormControl,
FormErrorMessage,
FormHelperText,
FormLabel,
Text,
VStack,
} from '@chakra-ui/react';
import { InvokeDiffusersModelConfigProps } from 'app/types/invokeai';
// import { addNewModel } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next';
import type { RootState } from 'app/store/store';
import IAIForm from 'common/components/IAIForm';
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
export default function AddDiffusersModel() {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
function hasWhiteSpace(s: string) {
return /\s/.test(s);
}
function baseValidation(value: string) {
let error;
if (hasWhiteSpace(value)) error = t('modelManager.cannotUseSpaces');
return error;
}
const addModelFormValues: InvokeDiffusersModelConfigProps = {
name: '',
description: '',
repo_id: '',
path: '',
format: 'diffusers',
default: false,
vae: {
repo_id: '',
path: '',
},
};
const addModelFormSubmitHandler = (
values: InvokeDiffusersModelConfigProps
) => {
const diffusersModelToAdd = values;
if (values.path === '') delete diffusersModelToAdd.path;
if (values.repo_id === '') delete diffusersModelToAdd.repo_id;
if (values.vae.path === '') delete diffusersModelToAdd.vae.path;
if (values.vae.repo_id === '') delete diffusersModelToAdd.vae.repo_id;
dispatch(addNewModel(diffusersModelToAdd));
dispatch(setAddNewModelUIOption(null));
};
return (
<Flex overflow="scroll" maxHeight={window.innerHeight - 270} width="100%">
<Formik
initialValues={addModelFormValues}
onSubmit={addModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<IAIForm onSubmit={handleSubmit} w="full">
<VStack rowGap={2}>
<IAIFormItemWrapper>
{/* Name */}
<FormControl
isInvalid={!!errors.name && touched.name}
isRequired
>
<FormLabel htmlFor="name" fontSize="sm">
{t('modelManager.name')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="name"
name="name"
type="text"
validate={baseValidation}
isRequired
/>
{!!errors.name && touched.name ? (
<FormErrorMessage>{errors.name}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.nameValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<IAIFormItemWrapper>
{/* Description */}
<FormControl
isInvalid={!!errors.description && touched.description}
isRequired
>
<FormLabel htmlFor="description" fontSize="sm">
{t('modelManager.description')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="description"
name="description"
type="text"
isRequired
/>
{!!errors.description && touched.description ? (
<FormErrorMessage>{errors.description}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.descriptionValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<IAIFormItemWrapper>
<Text fontWeight="bold" fontSize="sm">
{t('modelManager.formMessageDiffusersModelLocation')}
</Text>
<Text
sx={{
fontSize: 'sm',
fontStyle: 'italic',
}}
variant="subtext"
>
{t('modelManager.formMessageDiffusersModelLocationDesc')}
</Text>
{/* Path */}
<FormControl isInvalid={!!errors.path && touched.path}>
<FormLabel htmlFor="path" fontSize="sm">
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field as={IAIInput} id="path" name="path" type="text" />
{!!errors.path && touched.path ? (
<FormErrorMessage>{errors.path}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.modelLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
{/* Repo ID */}
<FormControl isInvalid={!!errors.repo_id && touched.repo_id}>
<FormLabel htmlFor="repo_id" fontSize="sm">
{t('modelManager.repo_id')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="repo_id"
name="repo_id"
type="text"
/>
{!!errors.repo_id && touched.repo_id ? (
<FormErrorMessage>{errors.repo_id}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.repoIDValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<IAIFormItemWrapper>
{/* VAE Path */}
<Text fontWeight="bold">
{t('modelManager.formMessageDiffusersVAELocation')}
</Text>
<Text
sx={{
fontSize: 'sm',
fontStyle: 'italic',
}}
variant="subtext"
>
{t('modelManager.formMessageDiffusersVAELocationDesc')}
</Text>
<FormControl
isInvalid={!!errors.vae?.path && touched.vae?.path}
>
<FormLabel htmlFor="vae.path" fontSize="sm">
{t('modelManager.vaeLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae.path"
name="vae.path"
type="text"
/>
{!!errors.vae?.path && touched.vae?.path ? (
<FormErrorMessage>{errors.vae?.path}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.vaeLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
{/* VAE Repo ID */}
<FormControl
isInvalid={!!errors.vae?.repo_id && touched.vae?.repo_id}
>
<FormLabel htmlFor="vae.repo_id" fontSize="sm">
{t('modelManager.vaeRepoID')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae.repo_id"
name="vae.repo_id"
type="text"
/>
{!!errors.vae?.repo_id && touched.vae?.repo_id ? (
<FormErrorMessage>{errors.vae?.repo_id}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.vaeRepoIDValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<IAIButton type="submit" isLoading={isProcessing}>
{t('modelManager.addModel')}
</IAIButton>
</VStack>
</IAIForm>
)}
</Formik>
</Flex>
);
}

View File

@ -0,0 +1,48 @@
import { ButtonGroup, Flex } from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import { useState } from 'react';
import AdvancedAddModels from './AdvancedAddModels';
import SimpleAddModels from './SimpleAddModels';
export default function AddModels() {
const [addModelMode, setAddModelMode] = useState<'simple' | 'advanced'>(
'simple'
);
return (
<Flex
flexDirection="column"
width="100%"
overflow="scroll"
maxHeight={window.innerHeight - 250}
gap={4}
>
<ButtonGroup isAttached>
<IAIButton
size="sm"
isChecked={addModelMode == 'simple'}
onClick={() => setAddModelMode('simple')}
>
Simple
</IAIButton>
<IAIButton
size="sm"
isChecked={addModelMode == 'advanced'}
onClick={() => setAddModelMode('advanced')}
>
Advanced
</IAIButton>
</ButtonGroup>
<Flex
sx={{
p: 4,
borderRadius: 4,
background: 'base.200',
_dark: { background: 'base.800' },
}}
>
{addModelMode === 'simple' && <SimpleAddModels />}
{addModelMode === 'advanced' && <AdvancedAddModels />}
</Flex>
</Flex>
);
}

View File

@ -0,0 +1,143 @@
import { Flex } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import { addToast } from 'features/system/store/systemSlice';
import { useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
import { CheckpointModelConfig } from 'services/api/types';
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
import BaseModelSelect from '../shared/BaseModelSelect';
import CheckpointConfigsSelect from '../shared/CheckpointConfigsSelect';
import ModelVariantSelect from '../shared/ModelVariantSelect';
type AdvancedAddCheckpointProps = {
model_path?: string;
};
export default function AdvancedAddCheckpoint(
props: AdvancedAddCheckpointProps
) {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { model_path } = props;
const advancedAddCheckpointForm = useForm<CheckpointModelConfig>({
initialValues: {
model_name: model_path
? model_path.split('\\').splice(-1)[0].split('.')[0]
: '',
base_model: 'sd-1',
model_type: 'main',
path: model_path ? model_path : '',
description: '',
model_format: 'checkpoint',
error: undefined,
vae: '',
variant: 'normal',
config: 'configs\\stable-diffusion\\v1-inference.yaml',
},
});
const [addMainModel] = useAddMainModelsMutation();
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
const advancedAddCheckpointFormHandler = (values: CheckpointModelConfig) => {
addMainModel({
body: values,
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: `Model Added: ${values.model_name}`,
status: 'success',
})
)
);
advancedAddCheckpointForm.reset();
// Close Advanced Panel in Scan Models tab
if (model_path) {
dispatch(setAdvancedAddScanModel(null));
}
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: 'Model Add Failed',
status: 'error',
})
)
);
}
});
};
return (
<form
onSubmit={advancedAddCheckpointForm.onSubmit((v) =>
advancedAddCheckpointFormHandler(v)
)}
style={{ width: '100%' }}
>
<Flex flexDirection="column" gap={2}>
<IAIMantineTextInput
label="Model Name"
required
{...advancedAddCheckpointForm.getInputProps('model_name')}
/>
<BaseModelSelect
{...advancedAddCheckpointForm.getInputProps('base_model')}
/>
<IAIMantineTextInput
label="Model Location"
required
{...advancedAddCheckpointForm.getInputProps('path')}
/>
<IAIMantineTextInput
label="Description"
{...advancedAddCheckpointForm.getInputProps('description')}
/>
<IAIMantineTextInput
label="VAE Location"
{...advancedAddCheckpointForm.getInputProps('vae')}
/>
<ModelVariantSelect
{...advancedAddCheckpointForm.getInputProps('variant')}
/>
<Flex flexDirection="column" width="100%" gap={2}>
{!useCustomConfig ? (
<CheckpointConfigsSelect
required
width="100%"
{...advancedAddCheckpointForm.getInputProps('config')}
/>
) : (
<IAIMantineTextInput
required
label="Custom Config File Location"
{...advancedAddCheckpointForm.getInputProps('config')}
/>
)}
<IAISimpleCheckbox
isChecked={useCustomConfig}
onChange={() => setUseCustomConfig(!useCustomConfig)}
label="Use Custom Config"
/>
<IAIButton mt={2} type="submit">
{t('modelManager.addModel')}
</IAIButton>
</Flex>
</Flex>
</form>
);
}
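
The default `model_name` above is derived with `model_path.split('\\')`, which only strips Windows-style separators (the diffusers form below does the same). A separator-agnostic helper, purely illustrative and not part of this commit, could look like:

```ts
// Derive a default model name from a path on either Windows or POSIX systems
const modelNameFromPath = (modelPath: string): string => {
  const fileName = modelPath.split(/[\\/]/).pop() ?? '';
  // Drop a trailing extension such as .ckpt or .safetensors, if present
  return fileName.replace(/\.[^.]+$/, '');
};

// modelNameFromPath('C:\\models\\analog-diffusion-1.0.safetensors') === 'analog-diffusion-1.0'
// modelNameFromPath('/home/user/models/sd-v1-5.ckpt') === 'sd-v1-5'
```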

View File

@ -0,0 +1,113 @@
import { Flex } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import { addToast } from 'features/system/store/systemSlice';
import { useTranslation } from 'react-i18next';
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
import { DiffusersModelConfig } from 'services/api/types';
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
import BaseModelSelect from '../shared/BaseModelSelect';
import ModelVariantSelect from '../shared/ModelVariantSelect';
type AdvancedAddDiffusersProps = {
model_path?: string;
};
export default function AdvancedAddDiffusers(props: AdvancedAddDiffusersProps) {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { model_path } = props;
const [addMainModel] = useAddMainModelsMutation();
const advancedAddDiffusersForm = useForm<DiffusersModelConfig>({
initialValues: {
model_name: model_path ? model_path.split('\\').splice(-1)[0] : '',
base_model: 'sd-1',
model_type: 'main',
path: model_path ? model_path : '',
description: '',
model_format: 'diffusers',
error: undefined,
vae: '',
variant: 'normal',
},
});
const advancedAddDiffusersFormHandler = (values: DiffusersModelConfig) => {
addMainModel({
body: values,
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: `Model Added: ${values.model_name}`,
status: 'success',
})
)
);
advancedAddDiffusersForm.reset();
// Close Advanced Panel in Scan Models tab
if (model_path) {
dispatch(setAdvancedAddScanModel(null));
}
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: 'Model Add Failed',
status: 'error',
})
)
);
}
});
};
return (
<form
onSubmit={advancedAddDiffusersForm.onSubmit((v) =>
advancedAddDiffusersFormHandler(v)
)}
style={{ width: '100%' }}
>
<Flex flexDirection="column" gap={2}>
<IAIMantineTextInput
required
label="Model Name"
{...advancedAddDiffusersForm.getInputProps('model_name')}
/>
<BaseModelSelect
{...advancedAddDiffusersForm.getInputProps('base_model')}
/>
<IAIMantineTextInput
required
label="Model Location"
placeholder="Provide the path to a local folder where your Diffusers Model is stored"
{...advancedAddDiffusersForm.getInputProps('path')}
/>
<IAIMantineTextInput
label="Description"
{...advancedAddDiffusersForm.getInputProps('description')}
/>
<IAIMantineTextInput
label="VAE Location"
{...advancedAddDiffusersForm.getInputProps('vae')}
/>
<ModelVariantSelect
{...advancedAddDiffusersForm.getInputProps('variant')}
/>
<IAIButton mt={2} type="submit">
{t('modelManager.addModel')}
</IAIButton>
</Flex>
</form>
);
}

View File

@ -0,0 +1,46 @@
import { Flex } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { useState } from 'react';
import AdvancedAddCheckpoint from './AdvancedAddCheckpoint';
import AdvancedAddDiffusers from './AdvancedAddDiffusers';
export const advancedAddModeData: SelectItem[] = [
{ label: 'Diffusers', value: 'diffusers' },
{ label: 'Checkpoint / Safetensors', value: 'checkpoint' },
];
export type ManualAddMode = 'diffusers' | 'checkpoint';
export default function AdvancedAddModels() {
const [advancedAddMode, setAdvancedAddMode] =
useState<ManualAddMode>('diffusers');
return (
<Flex flexDirection="column" gap={4} width="100%">
<IAIMantineSelect
label="Model Type"
value={advancedAddMode}
data={advancedAddModeData}
onChange={(v) => {
if (!v) return;
setAdvancedAddMode(v as ManualAddMode);
}}
/>
<Flex
sx={{
p: 4,
borderRadius: 4,
bg: 'base.300',
_dark: {
bg: 'base.850',
},
}}
>
{advancedAddMode === 'diffusers' && <AdvancedAddDiffusers />}
{advancedAddMode === 'checkpoint' && <AdvancedAddCheckpoint />}
</Flex>
</Flex>
);
}

View File

@ -0,0 +1,253 @@
import { Flex, Text } from '@chakra-ui/react';
import { makeToast } from 'app/components/Toaster';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAIScrollArea from 'common/components/IAIScrollArea';
import { addToast } from 'features/system/store/systemSlice';
import { difference, forEach, intersection, map, values } from 'lodash-es';
import { ChangeEvent, MouseEvent, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import {
SearchFolderResponse,
useGetMainModelsQuery,
useGetModelsInFolderQuery,
useImportMainModelsMutation,
} from 'services/api/endpoints/models';
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
export default function FoundModelsList() {
const searchFolder = useAppSelector(
(state: RootState) => state.modelmanager.searchFolder
);
const [nameFilter, setNameFilter] = useState<string>('');
// Get paths of models that are already installed
const { data: installedModels } = useGetMainModelsQuery();
// Get all model paths from a given directory
const { foundModels, alreadyInstalled, filteredModels } =
useGetModelsInFolderQuery(
{
search_path: searchFolder ? searchFolder : '',
},
{
selectFromResult: ({ data }) => {
const installedModelValues = values(installedModels?.entities);
const installedModelPaths = map(installedModelValues, 'path');
// Only select models that aren't already installed in Invoke
const notInstalledModels = difference(data, installedModelPaths);
const alreadyInstalled = intersection(data, installedModelPaths);
return {
foundModels: data,
alreadyInstalled: foundModelsFilter(alreadyInstalled, nameFilter),
filteredModels: foundModelsFilter(notInstalledModels, nameFilter),
};
},
}
);
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const quickAddHandler = useCallback(
(e: MouseEvent<HTMLButtonElement>) => {
const model_name = e.currentTarget.id.split('\\').splice(-1)[0];
importMainModel({
body: {
location: e.currentTarget.id,
},
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: `Added Model: ${model_name}`,
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: 'Failed To Add Model',
status: 'error',
})
)
);
}
});
},
[dispatch, importMainModel]
);
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setNameFilter(e.target.value);
}, []);
const renderModels = ({
models,
showActions = true,
}: {
models: string[];
showActions?: boolean;
}) => {
return models.map((model) => {
return (
<Flex
sx={{
p: 4,
gap: 4,
alignItems: 'center',
borderRadius: 4,
bg: 'base.200',
_dark: {
bg: 'base.800',
},
}}
key={model}
>
<Flex w="100%" sx={{ flexDirection: 'column', minW: '25%' }}>
<Text sx={{ fontWeight: 600 }}>
{model.split('\\').slice(-1)[0]}
</Text>
<Text
sx={{
fontSize: 'sm',
color: 'base.600',
_dark: {
color: 'base.400',
},
}}
>
{model}
</Text>
</Flex>
{showActions ? (
<Flex gap={2}>
<IAIButton
id={model}
onClick={quickAddHandler}
isLoading={isLoading}
>
Quick Add
</IAIButton>
<IAIButton
onClick={() => dispatch(setAdvancedAddScanModel(model))}
isLoading={isLoading}
>
Advanced
</IAIButton>
</Flex>
) : (
<Text
sx={{
fontWeight: 600,
p: 2,
borderRadius: 4,
color: 'accent.50',
bg: 'accent.400',
_dark: {
color: 'accent.100',
bg: 'accent.600',
},
}}
>
Installed
</Text>
)}
</Flex>
);
});
};
const renderFoundModels = () => {
if (!searchFolder) return;
if (!foundModels || foundModels.length === 0) {
return (
<Flex
sx={{
w: 'full',
h: 'full',
justifyContent: 'center',
alignItems: 'center',
height: 96,
userSelect: 'none',
bg: 'base.200',
_dark: {
bg: 'base.900',
},
}}
>
<Text variant="subtext">No Models Found</Text>
</Flex>
);
}
return (
<Flex
sx={{
flexDirection: 'column',
gap: 2,
w: '100%',
minW: '50%',
}}
>
<IAIInput
onChange={handleSearchFilter}
label={t('modelManager.search')}
labelPos="side"
/>
<Flex p={2} gap={2}>
<Text sx={{ fontWeight: 600 }}>
Models Found: {foundModels.length}
</Text>
<Text
sx={{
fontWeight: 600,
color: 'accent.500',
_dark: {
color: 'accent.200',
},
}}
>
Not Installed: {filteredModels.length}
</Text>
</Flex>
<IAIScrollArea offsetScrollbars>
<Flex gap={2} flexDirection="column">
{renderModels({ models: filteredModels })}
{renderModels({ models: alreadyInstalled, showActions: false })}
</Flex>
</IAIScrollArea>
</Flex>
);
};
return renderFoundModels();
}
const foundModelsFilter = (
data: SearchFolderResponse | undefined,
nameFilter: string
) => {
const filteredModels: SearchFolderResponse = [];
forEach(data, (model) => {
if (!model) {
return;
}
if (model.includes(nameFilter)) {
filteredModels.push(model);
}
});
return filteredModels;
};
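
`foundModelsFilter` keeps only the scanned paths that contain the filter string, and the match is case-sensitive. A quick usage sketch with hypothetical inputs:

```ts
// Hypothetical scan results, mirroring the shape returned by useGetModelsInFolderQuery
const scanned = [
  'C:\\models\\dreamshaper_8.safetensors',
  'C:\\models\\sd-v1-5.ckpt',
];

foundModelsFilter(scanned, 'dream');
// => ['C:\\models\\dreamshaper_8.safetensors']

foundModelsFilter(scanned, 'Dream');
// => [] — lower-casing both sides before the includes() check would make this match
```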

View File

@ -0,0 +1,99 @@
import { Box, Flex, Text } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { motion } from 'framer-motion';
import { useEffect, useState } from 'react';
import { FaTimes } from 'react-icons/fa';
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
import AdvancedAddCheckpoint from './AdvancedAddCheckpoint';
import AdvancedAddDiffusers from './AdvancedAddDiffusers';
import { ManualAddMode, advancedAddModeData } from './AdvancedAddModels';
export default function ScanAdvancedAddModels() {
const advancedAddScanModel = useAppSelector(
(state: RootState) => state.modelmanager.advancedAddScanModel
);
const [advancedAddMode, setAdvancedAddMode] =
useState<ManualAddMode>('diffusers');
const [isCheckpoint, setIsCheckpoint] = useState(
advancedAddScanModel &&
['.ckpt', '.safetensors', '.pth', '.pt'].some((ext) =>
advancedAddScanModel.endsWith(ext)
)
);
useEffect(() => {
isCheckpoint
? setAdvancedAddMode('checkpoint')
: setAdvancedAddMode('diffusers');
}, [setAdvancedAddMode, isCheckpoint]);
const dispatch = useAppDispatch();
return (
advancedAddScanModel && (
<Box
as={motion.div}
initial={{ x: -100, opacity: 0 }}
animate={{ x: 0, opacity: 1, transition: { duration: 0.2 } }}
sx={{
display: 'flex',
flexDirection: 'column',
minWidth: '40%',
maxHeight: window.innerHeight - 300,
overflow: 'scroll',
p: 4,
gap: 4,
borderRadius: 4,
bg: 'base.200',
_dark: {
bg: 'base.800',
},
}}
>
<Flex justifyContent="space-between" alignItems="center">
<Text size="xl" fontWeight={600}>
{isCheckpoint || advancedAddMode === 'checkpoint'
? 'Add Checkpoint Model'
: 'Add Diffusers Model'}
</Text>
<IAIIconButton
icon={<FaTimes />}
aria-label="Close Advanced"
onClick={() => dispatch(setAdvancedAddScanModel(null))}
size="sm"
/>
</Flex>
<IAIMantineSelect
label="Model Type"
value={advancedAddMode}
data={advancedAddModeData}
onChange={(v) => {
if (!v) return;
setAdvancedAddMode(v as ManualAddMode);
if (v === 'checkpoint') {
setIsCheckpoint(true);
} else {
setIsCheckpoint(false);
}
}}
/>
{isCheckpoint ? (
<AdvancedAddCheckpoint
key={advancedAddScanModel}
model_path={advancedAddScanModel}
/>
) : (
<AdvancedAddDiffusers
key={advancedAddScanModel}
model_path={advancedAddScanModel}
/>
)}
</Box>
)
);
}

View File

@ -0,0 +1,25 @@
import { Flex } from '@chakra-ui/react';
import FoundModelsList from './FoundModelsList';
import ScanAdvancedAddModels from './ScanAdvancedAddModels';
import SearchFolderForm from './SearchFolderForm';
export default function ScanModels() {
return (
<Flex flexDirection="column" w="100%" gap={4}>
<SearchFolderForm />
<Flex gap={4}>
<Flex
sx={{
maxHeight: window.innerHeight - 300,
overflow: 'scroll',
gap: 4,
w: '100%',
}}
>
<FoundModelsList />
</Flex>
<ScanAdvancedAddModels />
</Flex>
</Flex>
);
}

View File

@ -0,0 +1,139 @@
import { Flex, Text } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIInput from 'common/components/IAIInput';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { FaSearch, FaSync, FaTrash } from 'react-icons/fa';
import { useGetModelsInFolderQuery } from 'services/api/endpoints/models';
import {
setAdvancedAddScanModel,
setSearchFolder,
} from '../../store/modelManagerSlice';
type SearchFolderForm = {
folder: string;
};
function SearchFolderForm() {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const searchFolder = useAppSelector(
(state: RootState) => state.modelmanager.searchFolder
);
const { refetch: refetchFoundModels } = useGetModelsInFolderQuery({
search_path: searchFolder ? searchFolder : '',
});
const searchFolderForm = useForm<SearchFolderForm>({
initialValues: {
folder: '',
},
});
const searchFolderFormSubmitHandler = useCallback(
(values: SearchFolderForm) => {
dispatch(setSearchFolder(values.folder));
},
[dispatch]
);
const scanAgainHandler = () => {
refetchFoundModels();
};
return (
<form
onSubmit={searchFolderForm.onSubmit((values) =>
searchFolderFormSubmitHandler(values)
)}
style={{ width: '100%' }}
>
<Flex
sx={{
w: '100%',
gap: 2,
borderRadius: 4,
alignItems: 'center',
}}
>
<Flex w="100%" alignItems="center" gap={4} minH={12}>
<Text
sx={{
fontSize: 'sm',
fontWeight: 600,
color: 'base.700',
minW: 'max-content',
_dark: { color: 'base.300' },
}}
>
Folder
</Text>
{!searchFolder ? (
<IAIInput
w="100%"
size="md"
{...searchFolderForm.getInputProps('folder')}
/>
) : (
<Flex
sx={{
w: '100%',
p: 2,
px: 4,
bg: 'base.300',
borderRadius: 4,
fontSize: 'sm',
fontWeight: 'bold',
_dark: { bg: 'base.700' },
}}
>
{searchFolder}
</Flex>
)}
</Flex>
<Flex gap={2}>
{!searchFolder ? (
<IAIIconButton
aria-label={t('modelManager.findModels')}
tooltip={t('modelManager.findModels')}
icon={<FaSearch />}
fontSize={18}
size="sm"
type="submit"
/>
) : (
<IAIIconButton
aria-label={t('modelManager.scanAgain')}
tooltip={t('modelManager.scanAgain')}
icon={<FaSync />}
onClick={scanAgainHandler}
fontSize={18}
size="sm"
/>
)}
<IAIIconButton
aria-label={t('modelManager.clearCheckpointFolder')}
tooltip={t('modelManager.clearCheckpointFolder')}
icon={<FaTrash />}
size="sm"
onClick={() => {
dispatch(setSearchFolder(null));
dispatch(setAdvancedAddScanModel(null));
}}
isDisabled={!searchFolder}
colorScheme="red"
/>
</Flex>
</Flex>
</form>
);
}
export default memo(SearchFolderForm);

View File

@ -0,0 +1,108 @@
import { Flex } from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useTranslation } from 'react-i18next';
import { SelectItem } from '@mantine/core';
import { useForm } from '@mantine/form';
import { makeToast } from 'app/components/Toaster';
import { RootState } from 'app/store/store';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { addToast } from 'features/system/store/systemSlice';
import { useImportMainModelsMutation } from 'services/api/endpoints/models';
const predictionSelectData: SelectItem[] = [
{ label: 'None', value: 'none' },
{ label: 'v_prediction', value: 'v_prediction' },
{ label: 'epsilon', value: 'epsilon' },
{ label: 'sample', value: 'sample' },
];
type ExtendedImportModelConfig = {
location: string;
prediction_type?: 'v_prediction' | 'epsilon' | 'sample' | 'none' | undefined;
};
export default function SimpleAddModels() {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
const addModelForm = useForm<ExtendedImportModelConfig>({
initialValues: {
location: '',
prediction_type: undefined,
},
});
const handleAddModelSubmit = (values: ExtendedImportModelConfig) => {
const importModelResponseBody = {
location: values.location,
prediction_type:
values.prediction_type === 'none' ? undefined : values.prediction_type,
};
importMainModel({ body: importModelResponseBody })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: 'Model Added',
status: 'success',
})
)
);
addModelForm.reset();
})
.catch((error) => {
if (error) {
console.log(error);
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
};
return (
<form
onSubmit={addModelForm.onSubmit((v) => handleAddModelSubmit(v))}
style={{ width: '100%' }}
>
<Flex flexDirection="column" width="100%" gap={4}>
<IAIMantineTextInput
label="Model Location"
placeholder="Provide a path to a local Diffusers model, local checkpoint / safetensors model or a HuggingFace Repo ID"
w="100%"
{...addModelForm.getInputProps('location')}
/>
<IAIMantineSelect
label="Prediction Type (for Stable Diffusion 2.x Models only)"
data={predictionSelectData}
defaultValue="none"
{...addModelForm.getInputProps('prediction_type')}
/>
<IAIButton
type="submit"
isLoading={isLoading}
isDisabled={isLoading || isProcessing}
>
{t('modelManager.addModel')}
</IAIButton>
</Flex>
</form>
);
}
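
For reference, a minimal sketch of the import call this form ends up making, lifted out of the component. The hook and body fields come from the diff above; the example location is made up.

```tsx
import { useImportMainModelsMutation } from 'services/api/endpoints/models';

// Hypothetical helper hook, sketch only.
export const useImportExample = () => {
  const [importMainModel] = useImportMainModelsMutation();

  return () =>
    importMainModel({
      body: {
        // A local path or a HuggingFace repo ID, exactly as typed in the form.
        location: '/path/to/model.safetensors',
        // 'none' in the select is translated to undefined before sending.
        prediction_type: undefined,
      },
    }).unwrap();
};
```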

View File

@ -0,0 +1,39 @@
import { ButtonGroup, Flex } from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import { useState } from 'react';
import { useTranslation } from 'react-i18next';
import AddModels from './AddModelsPanel/AddModels';
import ScanModels from './AddModelsPanel/ScanModels';
type AddModelTabs = 'add' | 'scan';
export default function ImportModelsPanel() {
const [addModelTab, setAddModelTab] = useState<AddModelTabs>('add');
const { t } = useTranslation();
return (
<Flex flexDirection="column" gap={4}>
<ButtonGroup isAttached>
<IAIButton
onClick={() => setAddModelTab('add')}
isChecked={addModelTab == 'add'}
size="sm"
width="100%"
>
{t('modelManager.addModel')}
</IAIButton>
<IAIButton
onClick={() => setAddModelTab('scan')}
isChecked={addModelTab == 'scan'}
size="sm"
width="100%"
>
{t('modelManager.scanForModels')}
</IAIButton>
</ButtonGroup>
{addModelTab == 'add' && <AddModels />}
{addModelTab == 'scan' && <ScanModels />}
</Flex>
);
}

View File

@ -1,11 +1,4 @@
-import {
-  Flex,
-  Radio,
-  RadioGroup,
-  Text,
-  Tooltip,
-  useColorMode,
-} from '@chakra-ui/react';
+import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
@ -23,7 +16,6 @@ import {
  useMergeMainModelsMutation,
} from 'services/api/endpoints/models';
import { BaseModelType, MergeModelConfig } from 'services/api/types';
-import { mode } from 'theme/util/mode';
const baseModelTypeSelectData = [
  { label: 'Stable Diffusion 1', value: 'sd-1' },
@ -38,7 +30,6 @@ type MergeInterpolationMethods =
export default function MergeModelsPanel() {
  const { t } = useTranslation();
-  const { colorMode } = useColorMode();
  const dispatch = useAppDispatch();

  const { data } = useGetMainModelsQuery({
@ -128,9 +119,9 @@ export default function MergeModelsPanel() {
        mergedModelName !== '' ? mergedModelName : models_names.join('-'),
      alpha: modelMergeAlpha,
      interp: modelMergeInterp,
-      // model_merge_save_path:
-      //   modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
      force: modelMergeForce,
+      merge_dest_directory:
+        modelMergeSaveLocType === 'root' ? undefined : modelMergeCustomSaveLoc,
    };

    mergeModels({
@ -230,7 +221,10 @@ export default function MergeModelsPanel() {
          padding: 4,
          borderRadius: 'base',
          gap: 4,
-          bg: mode('base.100', 'base.800')(colorMode),
+          bg: 'base.200',
+          _dark: {
+            bg: 'base.800',
+          },
        }}
      >
        <IAISlider
@ -255,7 +249,10 @@ export default function MergeModelsPanel() {
          padding: 4,
          borderRadius: 'base',
          gap: 4,
-          bg: mode('base.100', 'base.800')(colorMode),
+          bg: 'base.200',
+          _dark: {
+            bg: 'base.800',
+          },
        }}
      >
        <Text fontWeight={500} fontSize="sm" variant="subtext">
@ -291,13 +288,16 @@ export default function MergeModelsPanel() {
          </RadioGroup>
        </Flex>

-        {/* <Flex
+        <Flex
          sx={{
            flexDirection: 'column',
            padding: 4,
            borderRadius: 'base',
            gap: 4,
-            bg: 'base.900',
+            bg: 'base.200',
+            _dark: {
+              bg: 'base.900',
+            },
          }}
        >
          <Flex columnGap={4}>
@ -327,7 +327,7 @@ export default function MergeModelsPanel() {
              onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
            />
          )}
-        </Flex> */}
+        </Flex>

        <IAISimpleCheckbox
          label={t('modelManager.ignoreMismatch')}
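
A rough sketch of the body fields this panel now submits, assuming the `MergeModelConfig` shape imported above. Only the fields visible in this hunk are shown; the concrete values are illustrative.

```tsx
import { MergeModelConfig } from 'services/api/types';

// Illustrative fragment, not the panel's actual state wiring.
export const exampleMergeBody: Partial<MergeModelConfig> = {
  alpha: 0.5,
  interp: 'weighted_sum',
  force: false,
  // New in this change: undefined keeps the default InvokeAI root location,
  // a string saves the merged model to a custom directory instead.
  merge_dest_directory: undefined,
};
```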

View File

@ -4,30 +4,23 @@ import { makeToast } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
-import IAIMantineSelect from 'common/components/IAIMantineSelect';
+import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
-import { useCallback } from 'react';
+import { useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import {
  CheckpointModelConfigEntity,
+  useGetCheckpointConfigsQuery,
  useUpdateMainModelsMutation,
} from 'services/api/endpoints/models';
import { CheckpointModelConfig } from 'services/api/types';
+import BaseModelSelect from '../shared/BaseModelSelect';
+import CheckpointConfigsSelect from '../shared/CheckpointConfigsSelect';
+import ModelVariantSelect from '../shared/ModelVariantSelect';
import ModelConvert from './ModelConvert';

-const baseModelSelectData = [
-  { value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
-  { value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
-];
-
-const variantSelectData = [
-  { value: 'normal', label: 'Normal' },
-  { value: 'inpaint', label: 'Inpaint' },
-  { value: 'depth', label: 'Depth' },
-];
-
type CheckpointModelEditProps = {
  model: CheckpointModelConfigEntity;
};
@ -38,6 +31,15 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
  const { model } = props;

  const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation();
+  const { data: availableCheckpointConfigs } = useGetCheckpointConfigsQuery();
+
+  const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
+
+  useEffect(() => {
+    if (!availableCheckpointConfigs?.includes(model.config)) {
+      setUseCustomConfig(true);
+    }
+  }, [availableCheckpointConfigs, model.config]);

  const dispatch = useAppDispatch();
  const { t } = useTranslation();
@ -80,7 +82,7 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
        )
      );
    })
-    .catch((error) => {
+    .catch((_) => {
      checkpointEditForm.reset();
      dispatch(
        addToast(
@ -128,21 +130,24 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
        )}
      >
        <Flex flexDirection="column" overflowY="scroll" gap={4}>
+          <IAIMantineTextInput
+            label={t('modelManager.name')}
+            {...checkpointEditForm.getInputProps('model_name')}
+          />
          <IAIMantineTextInput
            label={t('modelManager.description')}
            {...checkpointEditForm.getInputProps('description')}
          />
-          <IAIMantineSelect
-            label={t('modelManager.baseModel')}
-            data={baseModelSelectData}
+          <BaseModelSelect
+            required
            {...checkpointEditForm.getInputProps('base_model')}
          />
-          <IAIMantineSelect
-            label={t('modelManager.variant')}
-            data={variantSelectData}
+          <ModelVariantSelect
+            required
            {...checkpointEditForm.getInputProps('variant')}
          />
          <IAIMantineTextInput
+            required
            label={t('modelManager.modelLocation')}
            {...checkpointEditForm.getInputProps('path')}
          />
@ -150,10 +155,27 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
            label={t('modelManager.vaeLocation')}
            {...checkpointEditForm.getInputProps('vae')}
          />
-          <IAIMantineTextInput
-            label={t('modelManager.config')}
-            {...checkpointEditForm.getInputProps('config')}
-          />
+
+          <Flex flexDirection="column" gap={2}>
+            {!useCustomConfig ? (
+              <CheckpointConfigsSelect
+                required
+                {...checkpointEditForm.getInputProps('config')}
+              />
+            ) : (
+              <IAIMantineTextInput
+                required
+                label={t('modelManager.config')}
+                {...checkpointEditForm.getInputProps('config')}
+              />
+            )}
+            <IAISimpleCheckbox
+              isChecked={useCustomConfig}
+              onChange={() => setUseCustomConfig(!useCustomConfig)}
+              label="Use Custom Config"
+            />
+          </Flex>
+
          <IAIButton
            type="submit"
            isDisabled={isBusy || isLoading}

View File

@ -4,7 +4,6 @@ import { makeToast } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
-import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
@ -15,22 +14,13 @@ import {
  useUpdateMainModelsMutation,
} from 'services/api/endpoints/models';
import { DiffusersModelConfig } from 'services/api/types';
+import BaseModelSelect from '../shared/BaseModelSelect';
+import ModelVariantSelect from '../shared/ModelVariantSelect';

type DiffusersModelEditProps = {
  model: DiffusersModelConfigEntity;
};

-const baseModelSelectData = [
-  { value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
-  { value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
-];
-
-const variantSelectData = [
-  { value: 'normal', label: 'Normal' },
-  { value: 'inpaint', label: 'Inpaint' },
-  { value: 'depth', label: 'Depth' },
-];
-
export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
  const isBusy = useAppSelector(selectIsBusy);
@ -65,6 +55,7 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
      model_name: model.model_name,
      body: values,
    };
+
    updateMainModel(responseBody)
      .unwrap()
      .then((payload) => {
@ -78,7 +69,7 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
        )
      );
    })
-    .catch((error) => {
+    .catch((_) => {
      diffusersEditForm.reset();
      dispatch(
        addToast(
@ -118,21 +109,24 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
        )}
      >
        <Flex flexDirection="column" overflowY="scroll" gap={4}>
+          <IAIMantineTextInput
+            label={t('modelManager.name')}
+            {...diffusersEditForm.getInputProps('model_name')}
+          />
          <IAIMantineTextInput
            label={t('modelManager.description')}
            {...diffusersEditForm.getInputProps('description')}
          />
-          <IAIMantineSelect
-            label={t('modelManager.baseModel')}
-            data={baseModelSelectData}
+          <BaseModelSelect
+            required
            {...diffusersEditForm.getInputProps('base_model')}
          />
-          <IAIMantineSelect
-            label={t('modelManager.variant')}
-            data={variantSelectData}
+          <ModelVariantSelect
+            required
            {...diffusersEditForm.getInputProps('variant')}
          />
          <IAIMantineTextInput
+            required
            label={t('modelManager.modelLocation')}
            {...diffusersEditForm.getInputProps('path')}
          />

View File

@ -1,9 +1,18 @@
-import { Flex, ListItem, Text, UnorderedList } from '@chakra-ui/react';
-// import { convertToDiffusers } from 'app/socketio/actions';
+import {
+  Flex,
+  ListItem,
+  Radio,
+  RadioGroup,
+  Text,
+  Tooltip,
+  UnorderedList,
+} from '@chakra-ui/react';
import { makeToast } from 'app/components/Toaster';
+// import { convertToDiffusers } from 'app/socketio/actions';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton';
+import IAIInput from 'common/components/IAIInput';
import { addToast } from 'features/system/store/systemSlice';
import { useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
@ -15,6 +24,8 @@ interface ModelConvertProps {
  model: CheckpointModelConfig;
}

+type SaveLocation = 'InvokeAIRoot' | 'Custom';
+
export default function ModelConvert(props: ModelConvertProps) {
  const { model } = props;
@ -23,22 +34,51 @@ export default function ModelConvert(props: ModelConvertProps) {
  const [convertModel, { isLoading }] = useConvertMainModelsMutation();

-  const [saveLocation, setSaveLocation] = useState<string>('same');
+  const [saveLocation, setSaveLocation] =
+    useState<SaveLocation>('InvokeAIRoot');
  const [customSaveLocation, setCustomSaveLocation] = useState<string>('');

  useEffect(() => {
-    setSaveLocation('same');
+    setSaveLocation('InvokeAIRoot');
  }, [model]);

  const modelConvertCancelHandler = () => {
-    setSaveLocation('same');
+    setSaveLocation('InvokeAIRoot');
  };

  const modelConvertHandler = () => {
    const responseBody = {
      base_model: model.base_model,
      model_name: model.model_name,
+      params: {
+        convert_dest_directory:
+          saveLocation === 'Custom' ? customSaveLocation : undefined,
+      },
    };

+    if (saveLocation === 'Custom' && customSaveLocation === '') {
+      dispatch(
+        addToast(
+          makeToast({
+            title: t('modelManager.noCustomLocationProvided'),
+            status: 'error',
+          })
+        )
+      );
+      return;
+    }
+
+    dispatch(
+      addToast(
+        makeToast({
+          title: `${t('modelManager.convertingModelBegin')}: ${
+            model.model_name
+          }`,
+          status: 'success',
+        })
+      )
+    );
+
    convertModel(responseBody)
      .unwrap()
      .then((_) => {
@ -94,35 +134,30 @@ export default function ModelConvert(props: ModelConvertProps) {
          <Text>{t('modelManager.convertToDiffusersHelpText6')}</Text>
        </Flex>

-        {/* <Flex flexDir="column" gap={4}>
+        <Flex flexDir="column" gap={2}>
          <Flex marginTop={4} flexDir="column" gap={2}>
            <Text fontWeight="600">
              {t('modelManager.convertToDiffusersSaveLocation')}
            </Text>
-            <RadioGroup value={saveLocation} onChange={(v) => setSaveLocation(v)}>
+            <RadioGroup
+              value={saveLocation}
+              onChange={(v) => setSaveLocation(v as SaveLocation)}
+            >
              <Flex gap={4}>
-                <Radio value="same">
-                  <Tooltip label="Save converted model in the same folder">
-                    {t('modelManager.sameFolder')}
-                  </Tooltip>
-                </Radio>
-
-                <Radio value="root">
+                <Radio value="InvokeAIRoot">
                  <Tooltip label="Save converted model in the InvokeAI root folder">
                    {t('modelManager.invokeRoot')}
                  </Tooltip>
                </Radio>
-
-                <Radio value="custom">
+                <Radio value="Custom">
                  <Tooltip label="Save converted model in a custom folder">
                    {t('modelManager.custom')}
                  </Tooltip>
                </Radio>
              </Flex>
            </RadioGroup>
-          </Flex> */}
-
-        {/* {saveLocation === 'custom' && (
+          </Flex>
+          {saveLocation === 'Custom' && (
            <Flex flexDirection="column" rowGap={2}>
              <Text fontWeight="500" fontSize="sm" variant="subtext">
                {t('modelManager.customSaveLocation')}
@ -130,13 +165,13 @@ export default function ModelConvert(props: ModelConvertProps) {
              <IAIInput
                value={customSaveLocation}
                onChange={(e) => {
-                  if (e.target.value !== '') setCustomSaveLocation(e.target.value);
+                  setCustomSaveLocation(e.target.value);
                }}
                width="full"
              />
            </Flex>
-        )} */}
+          )}
+        </Flex>
      </IAIAlertDialog>
    );
}
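
A sketch of the resulting convert call: the destination directory now travels as a `params` object on the mutation argument instead of the previously commented-out body field. Values are illustrative; the argument shape follows `ConvertMainModelArg` in the endpoints file.

```tsx
import { useConvertMainModelsMutation } from 'services/api/endpoints/models';

// Hypothetical helper hook, sketch only.
export const useConvertExample = () => {
  const [convertModel] = useConvertMainModelsMutation();

  return () =>
    convertModel({
      base_model: 'sd-1',
      model_name: 'my-checkpoint',
      params: {
        // undefined keeps the converted model inside the InvokeAI root;
        // a string writes the diffusers output to that directory instead.
        convert_dest_directory: '/data/converted-models',
      },
    }).unwrap();
};
```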

View File

@ -44,10 +44,6 @@ const ModelList = (props: ModelListProps) => {
  return (
    <Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
-      <IAIInput
-        onChange={handleSearchFilter}
-        label={t('modelManager.search')}
-      />
      <Flex
        flexDirection="column"
        gap={4}
@ -79,6 +75,12 @@ const ModelList = (props: ModelListProps) => {
        </IAIButton>
      </ButtonGroup>

+      <IAIInput
+        onChange={handleSearchFilter}
+        label={t('modelManager.search')}
+        labelPos="side"
+      />
+
      {['all', 'diffusers'].includes(modelFormatFilter) &&
        filteredDiffusersModels.length > 0 && (
          <Flex sx={{ gap: 2, flexDir: 'column' }}>

View File

@ -1,10 +1,12 @@
import { DeleteIcon } from '@chakra-ui/icons';
import { Flex, Text, Tooltip } from '@chakra-ui/react';
-import { useAppSelector } from 'app/store/storeHooks';
+import { makeToast } from 'app/components/Toaster';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton';
import { selectIsBusy } from 'features/system/store/systemSelectors';
+import { addToast } from 'features/system/store/systemSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
@ -21,6 +23,7 @@ type ModelListItemProps = {
export default function ModelListItem(props: ModelListItemProps) {
  const isBusy = useAppSelector(selectIsBusy);
  const { t } = useTranslation();
+  const dispatch = useAppDispatch();
  const [deleteMainModel] = useDeleteMainModelsMutation();
  const { model, isSelected, setSelectedModelId } = props;
@ -30,9 +33,34 @@ export default function ModelListItem(props: ModelListItemProps) {
  }, [model.id, setSelectedModelId]);

  const handleModelDelete = useCallback(() => {
-    deleteMainModel(model);
+    deleteMainModel(model)
+      .unwrap()
+      .then((_) => {
+        dispatch(
+          addToast(
+            makeToast({
+              title: `${t('modelManager.modelDeleted')}: ${model.model_name}`,
+              status: 'success',
+            })
+          )
+        );
+      })
+      .catch((error) => {
+        if (error) {
+          dispatch(
+            addToast(
+              makeToast({
+                title: `${t('modelManager.modelDeleteFailed')}: ${
+                  model.model_name
+                }`,
+                status: 'error',
+              })
+            )
+          );
+        }
+      });
    setSelectedModelId(undefined);
-  }, [deleteMainModel, model, setSelectedModelId]);
+  }, [deleteMainModel, model, setSelectedModelId, dispatch, t]);

  return (
    <Flex sx={{ gap: 2, alignItems: 'center', w: 'full' }}>

View File

@ -0,0 +1,26 @@
import IAIMantineSelect, {
IAISelectDataType,
IAISelectProps,
} from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { useTranslation } from 'react-i18next';
const baseModelSelectData: IAISelectDataType[] = [
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
];
type BaseModelSelectProps = Omit<IAISelectProps, 'data'>;
export default function BaseModelSelect(props: BaseModelSelectProps) {
const { ...rest } = props;
const { t } = useTranslation();
return (
<IAIMantineSelect
label={t('modelManager.baseModel')}
data={baseModelSelectData}
{...rest}
/>
);
}

View File

@ -0,0 +1,22 @@
import IAIMantineSelect, {
IAISelectProps,
} from 'common/components/IAIMantineSelect';
import { useGetCheckpointConfigsQuery } from 'services/api/endpoints/models';
type CheckpointConfigSelectProps = Omit<IAISelectProps, 'data'>;
export default function CheckpointConfigsSelect(
props: CheckpointConfigSelectProps
) {
const { data: availableCheckpointConfigs } = useGetCheckpointConfigsQuery();
const { ...rest } = props;
return (
<IAIMantineSelect
label="Config File"
placeholder="Select A Config File"
data={availableCheckpointConfigs ? availableCheckpointConfigs : []}
{...rest}
/>
);
}

View File

@ -0,0 +1,26 @@
import IAIMantineSelect, {
IAISelectDataType,
IAISelectProps,
} from 'common/components/IAIMantineSelect';
import { useTranslation } from 'react-i18next';
const variantSelectData: IAISelectDataType[] = [
{ value: 'normal', label: 'Normal' },
{ value: 'inpaint', label: 'Inpaint' },
{ value: 'depth', label: 'Depth' },
];
type VariantSelectProps = Omit<IAISelectProps, 'data'>;
export default function ModelVariantSelect(props: VariantSelectProps) {
const { ...rest } = props;
const { t } = useTranslation();
return (
<IAIMantineSelect
label={t('modelManager.variant')}
data={variantSelectData}
{...rest}
/>
);
}
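
A minimal sketch of wiring the shared selects into a Mantine form, mirroring the edit-panel hunks above. The import paths are relative to those panels and the form fields here are a simplified stand-in, not a real config type.

```tsx
import { useForm } from '@mantine/form';
import BaseModelSelect from '../shared/BaseModelSelect';
import ModelVariantSelect from '../shared/ModelVariantSelect';

// Simplified stand-in for the real model config form values.
type EditableFields = {
  base_model: 'sd-1' | 'sd-2';
  variant: 'normal' | 'inpaint' | 'depth';
};

export default function ModelTypeFields() {
  const form = useForm<EditableFields>({
    initialValues: { base_model: 'sd-1', variant: 'normal' },
  });

  // getInputProps supplies value/onChange, so the selects stay controlled
  // without any per-field boilerplate in the edit panels.
  return (
    <>
      <BaseModelSelect required {...form.getInputProps('base_model')} />
      <ModelVariantSelect required {...form.getInputProps('variant')} />
    </>
  );
}
```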

View File

@ -4,6 +4,8 @@ import IAIIconButton from 'common/components/IAIIconButton';
import { canvasCopiedToClipboard } from 'features/canvas/store/actions';
import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
+import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
+import { useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaCopy } from 'react-icons/fa';
@ -11,6 +13,7 @@ import { FaCopy } from 'react-icons/fa';
export default function UnifiedCanvasCopyToClipboard() {
  const isStaging = useAppSelector(isStagingSelector);
  const canvasBaseLayer = getCanvasBaseLayer();
+  const { isClipboardAPIAvailable } = useCopyImageToClipboard();

  const isProcessing = useAppSelector(
    (state: RootState) => state.system.isProcessing
@ -25,15 +28,22 @@ export default function UnifiedCanvasCopyToClipboard() {
      handleCopyImageToClipboard();
    },
    {
-      enabled: () => !isStaging,
+      enabled: () => !isStaging && isClipboardAPIAvailable,
      preventDefault: true,
    },
-    [canvasBaseLayer, isProcessing]
+    [canvasBaseLayer, isProcessing, isClipboardAPIAvailable]
  );

-  const handleCopyImageToClipboard = () => {
+  const handleCopyImageToClipboard = useCallback(() => {
+    if (!isClipboardAPIAvailable) {
+      return;
+    }
    dispatch(canvasCopiedToClipboard());
-  };
+  }, [dispatch, isClipboardAPIAvailable]);
+
+  if (!isClipboardAPIAvailable) {
+    return null;
+  }

  return (
    <IAIIconButton

View File

@ -0,0 +1,52 @@
import { useAppToaster } from 'app/components/Toaster';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
export const useCopyImageToClipboard = () => {
const toaster = useAppToaster();
const { t } = useTranslation();
const isClipboardAPIAvailable = useMemo(() => {
return Boolean(navigator.clipboard) && Boolean(window.ClipboardItem);
}, []);
const copyImageToClipboard = useCallback(
async (image_url: string) => {
if (!isClipboardAPIAvailable) {
toaster({
title: t('toast.problemCopyingImage'),
description: "Your browser doesn't support the Clipboard API.",
status: 'error',
duration: 2500,
isClosable: true,
});
return;
}
try {
const response = await fetch(image_url);
const blob = await response.blob();
await navigator.clipboard.write([
new ClipboardItem({
[blob.type]: blob,
}),
]);
toaster({
title: t('toast.imageCopied'),
status: 'success',
duration: 2500,
isClosable: true,
});
} catch (err) {
toaster({
title: t('toast.problemCopyingImage'),
description: String(err),
status: 'error',
duration: 2500,
isClosable: true,
});
}
},
[isClipboardAPIAvailable, t, toaster]
);
return { isClipboardAPIAvailable, copyImageToClipboard };
};
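
A sketch of consuming the hook from any component that has an image URL. Only the hook API and its import path come from the file above; the button markup and component name are generic.

```tsx
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';

type Props = { imageUrl: string };

// Hypothetical consumer, sketch only.
export default function CopyImageButton({ imageUrl }: Props) {
  const { isClipboardAPIAvailable, copyImageToClipboard } =
    useCopyImageToClipboard();

  // Hide the control entirely when the browser lacks ClipboardItem support,
  // matching what UnifiedCanvasCopyToClipboard does with its hotkey/button.
  if (!isClipboardAPIAvailable) {
    return null;
  }

  return (
    <button onClick={() => copyImageToClipboard(imageUrl)}>
      Copy image to clipboard
    </button>
  );
}
```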

View File

@ -1,7 +1,5 @@
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';

-export type AddNewModelType = 'ckpt' | 'diffusers' | null;
-
export type Coordinates = {
  x: number;
  y: number;
@ -22,7 +20,6 @@ export interface UIState {
  shouldUseCanvasBetaLayout: boolean;
  shouldShowExistingModelsInSearch: boolean;
  shouldUseSliders: boolean;
-  addNewModelUIOption: AddNewModelType;
  shouldHidePreview: boolean;
  shouldPinGallery: boolean;
  shouldShowGallery: boolean;

View File

@ -5,7 +5,9 @@ import {
  BaseModelType,
  CheckpointModelConfig,
  ControlNetModelConfig,
+  ConvertModelConfig,
  DiffusersModelConfig,
+  ImportModelConfig,
  LoRAModelConfig,
  MainModelConfig,
  MergeModelConfig,
@ -13,8 +15,9 @@ import {
  VaeModelConfig,
} from 'services/api/types';

+import queryString from 'query-string';
import { ApiFullTagDescription, LIST_TAG, api } from '..';
-import { paths } from '../schema';
+import { operations, paths } from '../schema';

export type DiffusersModelConfigEntity = DiffusersModelConfig & { id: string };

export type CheckpointModelConfigEntity = CheckpointModelConfig & {
@ -62,6 +65,7 @@ type DeleteMainModelResponse = void;
type ConvertMainModelArg = {
  base_model: BaseModelType;
  model_name: string;
+  params: ConvertModelConfig;
};

type ConvertMainModelResponse =
@ -75,6 +79,28 @@ type MergeMainModelArg = {
type MergeMainModelResponse =
  paths['/api/v1/models/merge/{base_model}']['put']['responses']['200']['content']['application/json'];
type ImportMainModelArg = {
body: ImportModelConfig;
};
type ImportMainModelResponse =
paths['/api/v1/models/import']['post']['responses']['201']['content']['application/json'];
type AddMainModelArg = {
body: MainModelConfig;
};
type AddMainModelResponse =
paths['/api/v1/models/add']['post']['responses']['201']['content']['application/json'];
export type SearchFolderResponse =
paths['/api/v1/models/search']['get']['responses']['200']['content']['application/json'];
type CheckpointConfigsResponse =
paths['/api/v1/models/ckpt_confs']['get']['responses']['200']['content']['application/json'];
type SearchFolderArg = operations['search_for_models']['parameters']['query'];
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
  sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
@ -160,6 +186,29 @@ export const modelsApi = api.injectEndpoints({
      },
      invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
    }),
importMainModels: build.mutation<
ImportMainModelResponse,
ImportMainModelArg
>({
query: ({ body }) => {
return {
url: `models/import`,
method: 'POST',
body: body,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
query: ({ body }) => {
return {
url: `models/add`,
method: 'POST',
body: body,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
    deleteMainModels: build.mutation<
      DeleteMainModelResponse,
      DeleteMainModelArg
@ -176,10 +225,11 @@ export const modelsApi = api.injectEndpoints({
      ConvertMainModelResponse,
      ConvertMainModelArg
    >({
-      query: ({ base_model, model_name }) => {
+      query: ({ base_model, model_name, params }) => {
        return {
          url: `models/convert/${base_model}/main/${model_name}`,
          method: 'PUT',
+          params: params,
        };
      },
      invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
@ -328,6 +378,36 @@ export const modelsApi = api.injectEndpoints({
        );
      },
    }),
getModelsInFolder: build.query<SearchFolderResponse, SearchFolderArg>({
query: (arg) => {
const folderQueryStr = queryString.stringify(arg, {});
return {
url: `/models/search?${folderQueryStr}`,
};
},
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ type: 'ScannedModels', id: LIST_TAG },
];
if (result) {
tags.push(
...result.map((id) => ({
type: 'ScannedModels' as const,
id,
}))
);
}
return tags;
},
}),
getCheckpointConfigs: build.query<CheckpointConfigsResponse, void>({
query: () => {
return {
url: `/models/ckpt_confs`,
};
},
}),
  }),
});
@ -339,6 +419,10 @@ export const {
  useGetVaeModelsQuery,
  useUpdateMainModelsMutation,
  useDeleteMainModelsMutation,
+  useImportMainModelsMutation,
+  useAddMainModelsMutation,
  useConvertMainModelsMutation,
  useMergeMainModelsMutation,
+  useGetModelsInFolderQuery,
+  useGetCheckpointConfigsQuery,
} = modelsApi;
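
A sketch tying the new endpoints together from a component's point of view. Hook names and argument shapes come from the definitions above; the search path is an example value.

```tsx
import {
  useGetCheckpointConfigsQuery,
  useGetModelsInFolderQuery,
  useImportMainModelsMutation,
} from 'services/api/endpoints/models';

// Hypothetical helper hook, sketch only.
export const useModelManagerData = () => {
  // GET /api/v1/models/search?search_path=...
  const { data: foundModels } = useGetModelsInFolderQuery({
    search_path: '/data/models',
  });

  // GET /api/v1/models/ckpt_confs
  const { data: checkpointConfigs } = useGetCheckpointConfigsQuery();

  // POST /api/v1/models/import
  const [importMainModel, { isLoading: isImporting }] =
    useImportMainModelsMutation();

  return { foundModels, checkpointConfigs, importMainModel, isImporting };
};
```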

View File

@ -84,7 +84,7 @@ export type paths = {
      delete: operations["del_model"];
      /**
       * Update Model
-      * @description Add Model
+      * @description Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.
       */
      patch: operations["update_model"];
    };
@ -102,13 +102,6 @@ export type paths = {
       */
      post: operations["add_model"];
    };
-    "/api/v1/models/rename/{base_model}/{model_type}/{model_name}": {
-      /**
-       * Rename Model
-       * @description Rename a model
-       */
-      post: operations["rename_model"];
-    };
    "/api/v1/models/convert/{base_model}/{model_type}/{model_name}": {
      /**
       * Convert Model
@ -323,7 +316,7 @@ export type components = {
     * @description An enumeration.
     * @enum {string}
     */
-    BaseModelType: "sd-1" | "sd-2" | "sdxl" | "sdxl-refiner";
+    BaseModelType: "sd-1" | "sd-2";
    /** BoardChanges */
    BoardChanges: {
      /**
@ -1226,7 +1219,7 @@ export type components = {
* @description The nodes in this graph * @description The nodes in this graph
*/ */
nodes?: { nodes?: {
[key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | 
components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined; [key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["RealESRGANInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | 
components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
}; };
/** /**
* Edges * Edges
@ -1269,7 +1262,7 @@ export type components = {
     * @description The results of node executions
     */
    results: {
-      [key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["MetadataAccumulatorOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
+      [key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["MetadataAccumulatorOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
    };
    /**
     * Errors
@ -2067,12 +2060,6 @@ export type components = {
     * @default false
     */
    tiled?: boolean;
-    /**
-     * Fp32
-     * @description Decode in full precision
-     * @default false
-     */
-    fp32?: boolean;
  };
  /**
   * ImageUrlsDTO
@ -2528,12 +2515,6 @@ export type components = {
     * @default false
     */
    tiled?: boolean;
-    /**
-     * Fp32
-     * @description Decode in full precision
-     * @default false
-     */
-    fp32?: boolean;
    /**
     * Metadata
     * @description Optional core metadata to be written to the image
@ -3354,7 +3335,7 @@ export type components = {
    /** ModelsList */
    ModelsList: {
      /** Models */
-      models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"])[];
+      models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[];
    };
    /**
     * MultiplyInvocation
@ -3945,6 +3926,41 @@ export type components = {
       */
      step?: number;
    };
/**
* RealESRGANInvocation
* @description Upscales an image using RealESRGAN.
*/
RealESRGANInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default realesrgan
* @enum {string}
*/
type?: "realesrgan";
/**
* Image
* @description The input image
*/
image?: components["schemas"]["ImageField"];
/**
* Model Name
* @description The Real-ESRGAN model to use
* @default RealESRGAN_x4plus.pth
* @enum {string}
*/
model_name?: "RealESRGAN_x4plus.pth" | "RealESRGAN_x4plus_anime_6B.pth" | "ESRGAN_SRx4_DF2KOST_official-ff704c30.pth";
};
    /**
     * ResizeLatentsInvocation
     * @description Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8.
@ -4008,472 +4024,6 @@ export type components = {
     * @enum {string}
     */
    ResourceOrigin: "internal" | "external";
/**
* RestoreFaceInvocation
* @description Restores faces in an image.
*/
RestoreFaceInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default restore_face
* @enum {string}
*/
type?: "restore_face";
/**
* Image
* @description The input image
*/
image?: components["schemas"]["ImageField"];
/**
* Strength
* @description The strength of the restoration
* @default 0.75
*/
strength?: number;
};
/**
* SDXLCompelInvocation
* @description Parse prompt using compel package to conditioning.
*/
SDXLCompelInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default sdxl_compel
* @enum {string}
*/
type?: "sdxl_compel";
/**
* Prompt
* @description Prompt
* @default
*/
prompt?: string;
/**
* Clip1
* @description Clip to use
*/
clip1?: components["schemas"]["ClipField"];
/**
* Clip2
* @description Clip to use
*/
clip2?: components["schemas"]["ClipField"];
};
/**
* SDXLLatentsToLatentsInvocation
* @description Generates latents from conditionings.
*/
SDXLLatentsToLatentsInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default l2l_sdxl
* @enum {string}
*/
type?: "l2l_sdxl";
/**
* Positive Conditioning
* @description Positive conditioning for generation
*/
positive_conditioning?: components["schemas"]["ConditioningField"];
/**
* Negative Conditioning
* @description Negative conditioning for generation
*/
negative_conditioning?: components["schemas"]["ConditioningField"];
/**
* Noise
* @description The noise to use
*/
noise?: components["schemas"]["LatentsField"];
/**
* Steps
* @description The number of steps to use to generate the image
* @default 10
*/
steps?: number;
/**
* Cfg Scale
* @description The Classifier-Free Guidance, higher values may result in a result closer to the prompt
* @default 7.5
*/
cfg_scale?: number | (number)[];
/**
* Scheduler
* @description The scheduler to use
* @default euler
* @enum {string}
*/
scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc";
/**
* Unet
* @description UNet submodel
*/
unet?: components["schemas"]["UNetField"];
/**
* Latents
* @description Initial latents
*/
latents?: components["schemas"]["LatentsField"];
/**
* Denoising Start
* @default 0
*/
denoising_start?: number;
/**
* Denoising End
* @default 1
*/
denoising_end?: number;
};
/**
* SDXLModelLoaderInvocation
* @description Loads an sdxl base model, outputting its submodels.
*/
SDXLModelLoaderInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default sdxl_model_loader
* @enum {string}
*/
type?: "sdxl_model_loader";
/**
* Model
* @description The model to load
*/
model: components["schemas"]["MainModelField"];
};
/**
* SDXLModelLoaderOutput
* @description SDXL base model loader output
*/
SDXLModelLoaderOutput: {
/**
* Type
* @default sdxl_model_loader_output
* @enum {string}
*/
type?: "sdxl_model_loader_output";
/**
* Unet
* @description UNet submodel
*/
unet?: components["schemas"]["UNetField"];
/**
* Clip
* @description Tokenizer and text_encoder submodels
*/
clip?: components["schemas"]["ClipField"];
/**
* Clip2
* @description Tokenizer and text_encoder submodels
*/
clip2?: components["schemas"]["ClipField"];
/**
* Vae
* @description Vae submodel
*/
vae?: components["schemas"]["VaeField"];
};
/**
* SDXLRawPromptInvocation
* @description Parse prompt using compel package to conditioning.
*/
SDXLRawPromptInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default sdxl_raw_prompt
* @enum {string}
*/
type?: "sdxl_raw_prompt";
/**
* Prompt
* @description Prompt
* @default
*/
prompt?: string;
/**
* Style
* @description Style prompt
* @default
*/
style?: string;
/**
* Original Width
* @default 1024
*/
original_width?: number;
/**
* Original Height
* @default 1024
*/
original_height?: number;
/**
* Crop Top
* @default 0
*/
crop_top?: number;
/**
* Crop Left
* @default 0
*/
crop_left?: number;
/**
* Target Width
* @default 1024
*/
target_width?: number;
/**
* Target Height
* @default 1024
*/
target_height?: number;
/**
* Clip1
* @description Clip to use
*/
clip1?: components["schemas"]["ClipField"];
/**
* Clip2
* @description Clip to use
*/
clip2?: components["schemas"]["ClipField"];
};
/**
* SDXLRefinerModelLoaderInvocation
* @description Loads an sdxl refiner model, outputting its submodels.
*/
SDXLRefinerModelLoaderInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default sdxl_refiner_model_loader
* @enum {string}
*/
type?: "sdxl_refiner_model_loader";
/**
* Model
* @description The model to load
*/
model: components["schemas"]["MainModelField"];
};
/**
* SDXLRefinerModelLoaderOutput
* @description SDXL refiner model loader output
*/
SDXLRefinerModelLoaderOutput: {
/**
* Type
* @default sdxl_refiner_model_loader_output
* @enum {string}
*/
type?: "sdxl_refiner_model_loader_output";
/**
* Unet
* @description UNet submodel
*/
unet?: components["schemas"]["UNetField"];
/**
* Clip2
* @description Tokenizer and text_encoder submodels
*/
clip2?: components["schemas"]["ClipField"];
/**
* Vae
* @description Vae submodel
*/
vae?: components["schemas"]["VaeField"];
};
/**
* SDXLRefinerRawPromptInvocation
* @description Parse prompt using compel package to conditioning.
*/
SDXLRefinerRawPromptInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default sdxl_refiner_raw_prompt
* @enum {string}
*/
type?: "sdxl_refiner_raw_prompt";
/**
* Style
* @description Style prompt
* @default
*/
style?: string;
/**
* Original Width
* @default 1024
*/
original_width?: number;
/**
* Original Height
* @default 1024
*/
original_height?: number;
/**
* Crop Top
* @default 0
*/
crop_top?: number;
/**
* Crop Left
* @default 0
*/
crop_left?: number;
/**
* Aesthetic Score
* @default 6
*/
aesthetic_score?: number;
/**
* Clip2
* @description Clip to use
*/
clip2?: components["schemas"]["ClipField"];
};
/**
* SDXLTextToLatentsInvocation
* @description Generates latents from conditionings.
*/
SDXLTextToLatentsInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default t2l_sdxl
* @enum {string}
*/
type?: "t2l_sdxl";
/**
* Positive Conditioning
* @description Positive conditioning for generation
*/
positive_conditioning?: components["schemas"]["ConditioningField"];
/**
* Negative Conditioning
* @description Negative conditioning for generation
*/
negative_conditioning?: components["schemas"]["ConditioningField"];
/**
* Noise
* @description The noise to use
*/
noise?: components["schemas"]["LatentsField"];
/**
* Steps
* @description The number of steps to use to generate the image
* @default 10
*/
steps?: number;
/**
* Cfg Scale
* @description The Classifier-Free Guidance scale; higher values keep the result closer to the prompt
* @default 7.5
*/
cfg_scale?: number | (number)[];
/**
* Scheduler
* @description The scheduler to use
* @default euler
* @enum {string}
*/
scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc";
/**
* Unet
* @description UNet submodel
*/
unet?: components["schemas"]["UNetField"];
/**
* Denoising End
* @default 1
*/
denoising_end?: number;
};
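The scalar parameters of this node are set inline, while the conditioning, noise and unet inputs arrive over graph edges. A minimal sketch of a `t2l_sdxl` node configured to hand off to a refiner partway through the schedule (the import path and id are assumptions):

```typescript
// Sketch only: import path and node id are assumptions.
import type { components } from 'services/api/schema';

// An illustrative `t2l_sdxl` node. Conditioning, noise and unet are normally
// supplied through graph edges from the SDXL prompt, noise and model loader
// nodes, so only scalar parameters are set here.
const sdxlTextToLatents: components['schemas']['SDXLTextToLatentsInvocation'] = {
  id: 'sdxl_t2l_1',
  type: 't2l_sdxl',
  steps: 30,
  cfg_scale: 7.5,
  scheduler: 'euler',
  // Stop the base model at 80% of the schedule so an SDXL refiner pass can
  // finish the remaining 20% (a common base/refiner split).
  denoising_end: 0.8,
};
```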
/** /**
* ScaleLatentsInvocation * ScaleLatentsInvocation
* @description Scales latents by a given factor. * @description Scales latents by a given factor.
@ -4676,56 +4226,6 @@ export type components = {
vae?: string; vae?: string;
variant: components["schemas"]["ModelVariantType"]; variant: components["schemas"]["ModelVariantType"];
}; };
/** StableDiffusionXLModelCheckpointConfig */
StableDiffusionXLModelCheckpointConfig: {
/** Model Name */
model_name: string;
base_model: components["schemas"]["BaseModelType"];
/**
* Model Type
* @enum {string}
*/
model_type: "main";
/** Path */
path: string;
/** Description */
description?: string;
/**
* Model Format
* @enum {string}
*/
model_format: "checkpoint";
error?: components["schemas"]["ModelError"];
/** Vae */
vae?: string;
/** Config */
config: string;
variant: components["schemas"]["ModelVariantType"];
};
/** StableDiffusionXLModelDiffusersConfig */
StableDiffusionXLModelDiffusersConfig: {
/** Model Name */
model_name: string;
base_model: components["schemas"]["BaseModelType"];
/**
* Model Type
* @enum {string}
*/
model_type: "main";
/** Path */
path: string;
/** Description */
description?: string;
/**
* Model Format
* @enum {string}
*/
model_format: "diffusers";
error?: components["schemas"]["ModelError"];
/** Vae */
vae?: string;
variant: components["schemas"]["ModelVariantType"];
};
/** /**
* StepParamEasingInvocation * StepParamEasingInvocation
* @description Experimental per-step parameter easing for denoising steps * @description Experimental per-step parameter easing for denoising steps
@ -4813,7 +4313,7 @@ export type components = {
* @description An enumeration. * @description An enumeration.
* @enum {string} * @enum {string}
*/ */
SubModelType: "unet" | "text_encoder" | "text_encoder_2" | "tokenizer" | "tokenizer_2" | "vae" | "scheduler" | "safety_checker"; SubModelType: "unet" | "text_encoder" | "tokenizer" | "vae" | "scheduler" | "safety_checker";
/** /**
* SubtractInvocation * SubtractInvocation
* @description Subtracts two numbers * @description Subtracts two numbers
@ -4986,47 +4486,6 @@ export type components = {
*/ */
loras: (components["schemas"]["LoraInfo"])[]; loras: (components["schemas"]["LoraInfo"])[];
}; };
/**
* UpscaleInvocation
* @description Upscales an image.
*/
UpscaleInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default upscale
* @enum {string}
*/
type?: "upscale";
/**
* Image
* @description The input image
*/
image?: components["schemas"]["ImageField"];
/**
* Strength
* @description The strength
* @default 0.75
*/
strength?: number;
/**
* Level
* @description The upscale level
* @default 2
* @enum {integer}
*/
level?: 2 | 4;
};
/** /**
* VAEModelField * VAEModelField
* @description Vae model field * @description Vae model field
@ -5153,24 +4612,18 @@ export type components = {
*/ */
image?: components["schemas"]["ImageField"]; image?: components["schemas"]["ImageField"];
}; };
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionXLModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
/** /**
* StableDiffusion1ModelFormat * StableDiffusion1ModelFormat
* @description An enumeration. * @description An enumeration.
* @enum {string} * @enum {string}
*/ */
StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
}; };
responses: never; responses: never;
parameters: never; parameters: never;
@ -5281,7 +4734,7 @@ export type operations = {
}; };
requestBody: { requestBody: {
content: { content: {
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | 
components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; "application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["RealESRGANInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | 
components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
}; };
}; };
responses: { responses: {
@ -5318,7 +4771,7 @@ export type operations = {
}; };
requestBody: { requestBody: {
content: { content: {
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | 
components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; "application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["RealESRGANInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | 
components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
}; };
}; };
responses: { responses: {
@ -5516,8 +4969,8 @@ export type operations = {
list_models: { list_models: {
parameters: { parameters: {
query?: { query?: {
/** @description Base models to include */ /** @description Base model */
base_models?: (components["schemas"]["BaseModelType"])[]; base_model?: components["schemas"]["BaseModelType"];
/** @description The type of model to get */ /** @description The type of model to get */
model_type?: components["schemas"]["ModelType"]; model_type?: components["schemas"]["ModelType"];
}; };
@ -5567,7 +5020,7 @@ export type operations = {
}; };
/** /**
* Update Model * Update Model
* @description Add Model * @description Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.
*/ */
update_model: { update_model: {
parameters: { parameters: {
@ -5582,20 +5035,22 @@ export type operations = {
}; };
requestBody: { requestBody: {
content: { content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
}; };
}; };
responses: { responses: {
/** @description The model was updated successfully */ /** @description The model was updated successfully */
200: { 200: {
content: { content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
}; };
}; };
/** @description Bad request */ /** @description Bad request */
400: never; 400: never;
/** @description The model could not be found */ /** @description The model could not be found */
404: never; 404: never;
/** @description There is already a model corresponding to the new name */
409: never;
/** @description Validation Error */ /** @description Validation Error */
422: { 422: {
content: { content: {
@ -5618,7 +5073,7 @@ export type operations = {
/** @description The model imported successfully */ /** @description The model imported successfully */
201: { 201: {
content: { content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
}; };
}; };
/** @description The model could not be found */ /** @description The model could not be found */
@ -5642,14 +5097,14 @@ export type operations = {
add_model: { add_model: {
requestBody: { requestBody: {
content: { content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
}; };
}; };
responses: { responses: {
/** @description The model added successfully */ /** @description The model added successfully */
201: { 201: {
content: { content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
}; };
}; };
/** @description The model could not be found */ /** @description The model could not be found */
@ -5666,46 +5121,6 @@ export type operations = {
424: never; 424: never;
}; };
}; };
/**
* Rename Model
* @description Rename a model
*/
rename_model: {
parameters: {
query?: {
/** @description new model name */
new_name?: string;
/** @description new model base */
new_base?: components["schemas"]["BaseModelType"];
};
path: {
/** @description Base model */
base_model: components["schemas"]["BaseModelType"];
/** @description The type of model */
model_type: components["schemas"]["ModelType"];
/** @description current model name */
model_name: string;
};
};
responses: {
/** @description The model was renamed successfully */
201: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
};
};
/** @description The model could not be found */
404: never;
/** @description There is already a model corresponding to the new name */
409: never;
/** @description Validation Error */
422: {
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
/** /**
* Convert Model * Convert Model
* @description Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none. * @description Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none.
@ -5729,7 +5144,7 @@ export type operations = {
/** @description Model converted successfully */ /** @description Model converted successfully */
200: { 200: {
content: { content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
}; };
}; };
/** @description Bad request */ /** @description Bad request */
@ -5818,7 +5233,7 @@ export type operations = {
/** @description Model converted successfully */ /** @description Model converted successfully */
200: { 200: {
content: { content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
}; };
}; };
/** @description Incompatible models */ /** @description Incompatible models */


@ -28,6 +28,7 @@ export type OffsetPaginatedResults_ImageDTO_ =
// Models // Models
export type ModelType = components['schemas']['ModelType']; export type ModelType = components['schemas']['ModelType'];
export type SubModelType = components['schemas']['SubModelType'];
export type BaseModelType = components['schemas']['BaseModelType']; export type BaseModelType = components['schemas']['BaseModelType'];
export type MainModelField = components['schemas']['MainModelField']; export type MainModelField = components['schemas']['MainModelField'];
export type VAEModelField = components['schemas']['VAEModelField']; export type VAEModelField = components['schemas']['VAEModelField'];
@ -57,7 +58,10 @@ export type AnyModelConfig =
| ControlNetModelConfig | ControlNetModelConfig
| TextualInversionModelConfig | TextualInversionModelConfig
| MainModelConfig; | MainModelConfig;
export type MergeModelConfig = components['schemas']['Body_merge_models']; export type MergeModelConfig = components['schemas']['Body_merge_models'];
export type ConvertModelConfig = components['schemas']['Body_convert_model'];
export type ImportModelConfig = components['schemas']['Body_import_model'];
// Graphs // Graphs
export type Graph = components['schemas']['Graph']; export type Graph = components['schemas']['Graph'];


@ -5,6 +5,8 @@ import {
InvocationCompleteEvent, InvocationCompleteEvent,
InvocationErrorEvent, InvocationErrorEvent,
InvocationStartedEvent, InvocationStartedEvent,
ModelLoadCompletedEvent,
ModelLoadStartedEvent,
} from 'services/events/types'; } from 'services/events/types';
// Common socket action payload data // Common socket action payload data
@ -162,3 +164,35 @@ export const socketGeneratorProgress = createAction<
export const appSocketGeneratorProgress = createAction< export const appSocketGeneratorProgress = createAction<
BaseSocketPayload & { data: GeneratorProgressEvent } BaseSocketPayload & { data: GeneratorProgressEvent }
>('socket/appSocketGeneratorProgress'); >('socket/appSocketGeneratorProgress');
/**
* Socket.IO Model Load Started
*
* Do not use. Only for use in middleware.
*/
export const socketModelLoadStarted = createAction<
BaseSocketPayload & { data: ModelLoadStartedEvent }
>('socket/socketModelLoadStarted');
/**
* App-level Model Load Started
*/
export const appSocketModelLoadStarted = createAction<
BaseSocketPayload & { data: ModelLoadStartedEvent }
>('socket/appSocketModelLoadStarted');
/**
* Socket.IO Model Load Completed
*
* Do not use. Only for use in middleware.
*/
export const socketModelLoadCompleted = createAction<
BaseSocketPayload & { data: ModelLoadCompletedEvent }
>('socket/socketModelLoadCompleted');
/**
* App-level Model Load Completed
*/
export const appSocketModelLoadCompleted = createAction<
BaseSocketPayload & { data: ModelLoadCompletedEvent }
>('socket/appSocketModelLoadCompleted');
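The socket-scoped actions above exist only to be re-dispatched as their app-level counterparts from middleware. A minimal sketch of that bridge using Redux Toolkit's listener middleware (the import paths are assumptions):

```typescript
// Sketch only: assumes the actions above live at 'services/events/actions'.
import { createListenerMiddleware } from '@reduxjs/toolkit';
import {
  socketModelLoadStarted,
  appSocketModelLoadStarted,
  socketModelLoadCompleted,
  appSocketModelLoadCompleted,
} from 'services/events/actions';

export const socketListenerMiddleware = createListenerMiddleware();

// Re-dispatch the socket-scoped actions as their app-level counterparts so the
// rest of the app never has to import the "do not use" socket actions.
socketListenerMiddleware.startListening({
  actionCreator: socketModelLoadStarted,
  effect: (action, { dispatch }) => {
    dispatch(appSocketModelLoadStarted(action.payload));
  },
});

socketListenerMiddleware.startListening({
  actionCreator: socketModelLoadCompleted,
  effect: (action, { dispatch }) => {
    dispatch(appSocketModelLoadCompleted(action.payload));
  },
});
```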


@ -1,5 +1,11 @@
import { O } from 'ts-toolbelt'; import { O } from 'ts-toolbelt';
import { Graph, GraphExecutionState } from '../api/types'; import {
BaseModelType,
Graph,
GraphExecutionState,
ModelType,
SubModelType,
} from '../api/types';
/** /**
* A progress image, we get one for each step in the generation * A progress image, we get one for each step in the generation
@ -25,6 +31,25 @@ export type BaseNode = {
[key: string]: AnyInvocation[keyof AnyInvocation]; [key: string]: AnyInvocation[keyof AnyInvocation];
}; };
export type ModelLoadStartedEvent = {
graph_execution_state_id: string;
model_name: string;
base_model: BaseModelType;
model_type: ModelType;
submodel: SubModelType;
};
export type ModelLoadCompletedEvent = {
graph_execution_state_id: string;
model_name: string;
base_model: BaseModelType;
model_type: ModelType;
submodel: SubModelType;
hash?: string;
location: string;
precision: string;
};
/** /**
* A `generator_progress` socket.io event. * A `generator_progress` socket.io event.
* *
@ -101,6 +126,8 @@ export type ServerToClientEvents = {
graph_execution_state_complete: ( graph_execution_state_complete: (
payload: GraphExecutionStateCompleteEvent payload: GraphExecutionStateCompleteEvent
) => void; ) => void;
model_load_started: (payload: ModelLoadStartedEvent) => void;
model_load_completed: (payload: ModelLoadCompletedEvent) => void;
}; };
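With `model_load_started` and `model_load_completed` added to `ServerToClientEvents`, a typed socket.io client infers the payload types automatically. A minimal sketch, with a placeholder URL and assumed import paths:

```typescript
// Sketch only: the URL is a placeholder and the import path is assumed.
import { io, Socket } from 'socket.io-client';
import {
  ClientToServerEvents,
  ServerToClientEvents,
} from 'services/events/types';

const socket: Socket<ServerToClientEvents, ClientToServerEvents> = io(
  'http://localhost:9090'
);

// `payload` is inferred as ModelLoadCompletedEvent from ServerToClientEvents.
socket.on('model_load_completed', (payload) => {
  console.log(
    `loaded ${payload.model_name} (${payload.base_model}/${payload.submodel}) at ${payload.precision}`
  );
});
```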
export type ClientToServerEvents = { export type ClientToServerEvents = {


@ -11,6 +11,8 @@ import {
socketConnected, socketConnected,
socketDisconnected, socketDisconnected,
socketSubscribed, socketSubscribed,
socketModelLoadStarted,
socketModelLoadCompleted,
} from '../actions'; } from '../actions';
import { ClientToServerEvents, ServerToClientEvents } from '../types'; import { ClientToServerEvents, ServerToClientEvents } from '../types';
import { Logger } from 'roarr'; import { Logger } from 'roarr';
@ -44,7 +46,7 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
socketSubscribed({ socketSubscribed({
sessionId, sessionId,
timestamp: getTimestamp(), timestamp: getTimestamp(),
boardId: getState().boards.selectedBoardId, boardId: getState().gallery.selectedBoardId,
}) })
); );
} }
@ -118,4 +120,28 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
}) })
); );
}); });
/**
* Model load started
*/
socket.on('model_load_started', (data) => {
dispatch(
socketModelLoadStarted({
data,
timestamp: getTimestamp(),
})
);
});
/**
* Model load completed
*/
socket.on('model_load_completed', (data) => {
dispatch(
socketModelLoadCompleted({
data,
timestamp: getTimestamp(),
})
);
});
}; };


@ -16,7 +16,7 @@ const invokeAI = defineStyle((props) => {
}; };
return { return {
bg: mode('base.200', 'base.600')(props), bg: mode('base.250', 'base.600')(props),
color: mode('base.850', 'base.100')(props), color: mode('base.850', 'base.100')(props),
borderRadius: 'base', borderRadius: 'base',
svg: { svg: {


@ -1,6 +1,7 @@
import { menuAnatomy } from '@chakra-ui/anatomy'; import { menuAnatomy } from '@chakra-ui/anatomy';
import { createMultiStyleConfigHelpers } from '@chakra-ui/react'; import { createMultiStyleConfigHelpers } from '@chakra-ui/react';
import { mode } from '@chakra-ui/theme-tools'; import { mode } from '@chakra-ui/theme-tools';
import { MotionProps } from 'framer-motion';
const { definePartsStyle, defineMultiStyleConfig } = const { definePartsStyle, defineMultiStyleConfig } =
createMultiStyleConfigHelpers(menuAnatomy.keys); createMultiStyleConfigHelpers(menuAnatomy.keys);
@ -21,6 +22,7 @@ const invokeAI = definePartsStyle((props) => ({
}, },
list: { list: {
zIndex: 9999, zIndex: 9999,
color: mode('base.900', 'base.150')(props),
bg: mode('base.200', 'base.800')(props), bg: mode('base.200', 'base.800')(props),
shadow: 'dark-lg', shadow: 'dark-lg',
border: 'none', border: 'none',
@ -35,6 +37,9 @@ const invokeAI = definePartsStyle((props) => ({
_focus: { _focus: {
bg: mode('base.400', 'base.600')(props), bg: mode('base.400', 'base.600')(props),
}, },
svg: {
opacity: 0.5,
},
}, },
})); }));
@ -46,3 +51,28 @@ export const menuTheme = defineMultiStyleConfig({
variant: 'invokeAI', variant: 'invokeAI',
}, },
}); });
export const menuListMotionProps: MotionProps = {
variants: {
enter: {
visibility: 'visible',
opacity: 1,
scale: 1,
transition: {
duration: 0.07,
ease: [0.4, 0, 0.2, 1],
},
},
exit: {
transitionEnd: {
visibility: 'hidden',
},
opacity: 0,
scale: 0.8,
transition: {
duration: 0.07,
easings: 'easeOut',
},
},
},
};
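These motion props are meant to be handed to Chakra's `MenuList`. A minimal sketch, assuming `MenuList` forwards a `motionProps` prop to its framer-motion wrapper and that the theme module path shown is correct:

```tsx
// Sketch only: assumes Chakra UI v2's MenuList accepts a `motionProps` prop
// and that this theme file is importable as 'theme/components/menu'.
import { Menu, MenuButton, MenuItem, MenuList } from '@chakra-ui/react';
import { menuListMotionProps } from 'theme/components/menu';

export const ExampleMenu = () => (
  <Menu>
    <MenuButton>Options</MenuButton>
    {/* Apply the shortened enter/exit transition defined above. */}
    <MenuList motionProps={menuListMotionProps}>
      <MenuItem>Rename</MenuItem>
      <MenuItem>Delete</MenuItem>
    </MenuList>
  </Menu>
);
```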

File diff suppressed because it is too large


@ -1 +1 @@
__version__ = "3.0.0+b5" __version__ = "3.0.0+b6"

pull_request_template.md (new file, 45 lines added)

@ -0,0 +1,45 @@
## What type of PR is this? (check all applicable)
- [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
## Have you discussed this change with the InvokeAI team?
- [ ] Yes
- [ ] No, because:
## Description
## Related Tickets & Documents
<!--
For pull requests that relate or close an issue, please include them
below.
For example having the text: "closes #1234" would connect the current pull
request to issue 1234. And when we merge the pull request, Github will
automatically close the issue.
-->
- Related Issue #
- Closes #
## QA Instructions, Screenshots, Recordings
<!--
Please provide steps on how to test changes, any hardware or
software specifications as well as any other pertinent information.
-->
## Added/updated tests?
- [ ] Yes
- [ ] No : _please replace this line with details on why tests
have not been included_
## [optional] Are there any post deployment tasks we need to perform?


@ -55,7 +55,6 @@ def mock_services() -> InvocationServices:
), ),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(), processor = DefaultInvocationProcessor(),
restoration = None, # type: ignore
configuration = None, # type: ignore configuration = None, # type: ignore
) )


@ -48,7 +48,6 @@ def mock_services() -> InvocationServices:
), ),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(), processor = DefaultInvocationProcessor(),
restoration = None, # type: ignore
configuration = None, # type: ignore configuration = None, # type: ignore
) )


@ -1,6 +1,6 @@
from .test_nodes import ImageToImageTestInvocation, TextToImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation from .test_nodes import ImageToImageTestInvocation, TextToImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
from invokeai.app.invocations.upscale import UpscaleInvocation from invokeai.app.invocations.upscale import RealESRGANInvocation
from invokeai.app.invocations.image import * from invokeai.app.invocations.image import *
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
from invokeai.app.invocations.params import ParamIntInvocation from invokeai.app.invocations.params import ParamIntInvocation
@ -19,7 +19,7 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg
def test_connections_are_compatible(): def test_connections_are_compatible():
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
from_field = "image" from_field = "image"
to_node = UpscaleInvocation(id = "2") to_node = RealESRGANInvocation(id = "2")
to_field = "image" to_field = "image"
result = are_connections_compatible(from_node, from_field, to_node, to_field) result = are_connections_compatible(from_node, from_field, to_node, to_field)
@ -29,7 +29,7 @@ def test_connections_are_compatible():
def test_connections_are_incompatible(): def test_connections_are_incompatible():
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
from_field = "image" from_field = "image"
to_node = UpscaleInvocation(id = "2") to_node = RealESRGANInvocation(id = "2")
to_field = "strength" to_field = "strength"
result = are_connections_compatible(from_node, from_field, to_node, to_field) result = are_connections_compatible(from_node, from_field, to_node, to_field)
@ -39,7 +39,7 @@ def test_connections_are_incompatible():
def test_connections_incompatible_with_invalid_fields(): def test_connections_incompatible_with_invalid_fields():
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
from_field = "invalid_field" from_field = "invalid_field"
to_node = UpscaleInvocation(id = "2") to_node = RealESRGANInvocation(id = "2")
to_field = "image" to_field = "image"
# From field is invalid # From field is invalid
@ -86,10 +86,10 @@ def test_graph_fails_to_update_node_if_type_changes():
g = Graph() g = Graph()
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n) g.add_node(n)
n2 = UpscaleInvocation(id = "2") n2 = RealESRGANInvocation(id = "2")
g.add_node(n2) g.add_node(n2)
nu = UpscaleInvocation(id = "1") nu = RealESRGANInvocation(id = "1")
with pytest.raises(TypeError): with pytest.raises(TypeError):
g.update_node("1", nu) g.update_node("1", nu)
@ -98,7 +98,7 @@ def test_graph_allows_non_conflicting_id_change():
g = Graph() g = Graph()
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n) g.add_node(n)
n2 = UpscaleInvocation(id = "2") n2 = RealESRGANInvocation(id = "2")
g.add_node(n2) g.add_node(n2)
e1 = create_edge(n.id,"image",n2.id,"image") e1 = create_edge(n.id,"image",n2.id,"image")
g.add_edge(e1) g.add_edge(e1)
@ -128,7 +128,7 @@ def test_graph_fails_to_update_node_id_if_conflict():
def test_graph_adds_edge(): def test_graph_adds_edge():
g = Graph() g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = RealESRGANInvocation(id = "2")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
e = create_edge(n1.id,"image",n2.id,"image") e = create_edge(n1.id,"image",n2.id,"image")
@ -139,7 +139,7 @@ def test_graph_adds_edge():
def test_graph_fails_to_add_edge_with_cycle(): def test_graph_fails_to_add_edge_with_cycle():
g = Graph() g = Graph()
n1 = UpscaleInvocation(id = "1") n1 = RealESRGANInvocation(id = "1")
g.add_node(n1) g.add_node(n1)
e = create_edge(n1.id,"image",n1.id,"image") e = create_edge(n1.id,"image",n1.id,"image")
with pytest.raises(InvalidEdgeError): with pytest.raises(InvalidEdgeError):
@ -148,8 +148,8 @@ def test_graph_fails_to_add_edge_with_cycle():
def test_graph_fails_to_add_edge_with_long_cycle(): def test_graph_fails_to_add_edge_with_long_cycle():
g = Graph() g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = RealESRGANInvocation(id = "2")
n3 = UpscaleInvocation(id = "3") n3 = RealESRGANInvocation(id = "3")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
g.add_node(n3) g.add_node(n3)
@ -164,7 +164,7 @@ def test_graph_fails_to_add_edge_with_long_cycle():
def test_graph_fails_to_add_edge_with_missing_node_id(): def test_graph_fails_to_add_edge_with_missing_node_id():
g = Graph() g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = RealESRGANInvocation(id = "2")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
e1 = create_edge("1","image","3","image") e1 = create_edge("1","image","3","image")
@ -177,8 +177,8 @@ def test_graph_fails_to_add_edge_with_missing_node_id():
def test_graph_fails_to_add_edge_when_destination_exists(): def test_graph_fails_to_add_edge_when_destination_exists():
g = Graph() g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = RealESRGANInvocation(id = "2")
n3 = UpscaleInvocation(id = "3") n3 = RealESRGANInvocation(id = "3")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
g.add_node(n3) g.add_node(n3)
@ -194,7 +194,7 @@ def test_graph_fails_to_add_edge_when_destination_exists():
def test_graph_fails_to_add_edge_with_mismatched_types(): def test_graph_fails_to_add_edge_with_mismatched_types():
g = Graph() g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = RealESRGANInvocation(id = "2")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
e1 = create_edge("1","image","2","strength") e1 = create_edge("1","image","2","strength")
@ -344,7 +344,7 @@ def test_graph_iterator_invalid_if_output_and_input_types_different():
def test_graph_validates(): def test_graph_validates():
g = Graph() g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = RealESRGANInvocation(id = "2")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
e1 = create_edge("1","image","2","image") e1 = create_edge("1","image","2","image")
@ -377,8 +377,8 @@ def test_graph_invalid_if_subgraph_invalid():
def test_graph_invalid_if_has_cycle(): def test_graph_invalid_if_has_cycle():
g = Graph() g = Graph()
n1 = UpscaleInvocation(id = "1") n1 = RealESRGANInvocation(id = "1")
n2 = UpscaleInvocation(id = "2") n2 = RealESRGANInvocation(id = "2")
g.nodes[n1.id] = n1 g.nodes[n1.id] = n1
g.nodes[n2.id] = n2 g.nodes[n2.id] = n2
e1 = create_edge("1","image","2","image") e1 = create_edge("1","image","2","image")
@ -391,7 +391,7 @@ def test_graph_invalid_if_has_cycle():
def test_graph_invalid_with_invalid_connection(): def test_graph_invalid_with_invalid_connection():
g = Graph() g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = RealESRGANInvocation(id = "2")
g.nodes[n1.id] = n1 g.nodes[n1.id] = n1
g.nodes[n2.id] = n2 g.nodes[n2.id] = n2
e1 = create_edge("1","image","2","strength") e1 = create_edge("1","image","2","strength")
@ -503,7 +503,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
g.add_node(n1) g.add_node(n1)
n2 = UpscaleInvocation(id = "2") n2 = RealESRGANInvocation(id = "2")
g.add_node(n2) g.add_node(n2)
with pytest.raises(NodeNotFoundError): with pytest.raises(NodeNotFoundError):
@ -512,7 +512,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
def test_graph_gets_networkx_graph(): def test_graph_gets_networkx_graph():
g = Graph() g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = RealESRGANInvocation(id = "2")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
e = create_edge(n1.id,"image",n2.id,"image") e = create_edge(n1.id,"image",n2.id,"image")
@ -529,7 +529,7 @@ def test_graph_gets_networkx_graph():
def test_graph_can_serialize(): def test_graph_can_serialize():
g = Graph() g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = RealESRGANInvocation(id = "2")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
e = create_edge(n1.id,"image",n2.id,"image") e = create_edge(n1.id,"image",n2.id,"image")
@ -541,7 +541,7 @@ def test_graph_can_serialize():
def test_graph_can_deserialize(): def test_graph_can_deserialize():
g = Graph() g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = RealESRGANInvocation(id = "2")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
e = create_edge(n1.id,"image",n2.id,"image") e = create_edge(n1.id,"image",n2.id,"image")