merge with main

This commit is contained in:
Lincoln Stein 2023-07-18 07:00:55 -04:00
commit 9c3c393b84
98 changed files with 5874 additions and 7975 deletions

View File

@ -132,8 +132,10 @@ and go to http://localhost:9090.
### Command-Line Installation (for developers and users familiar with Terminals)
You must have Python 3.9 or 3.10 installed on your machine. Earlier or later versions are
not supported.
You must have Python 3.9 or 3.10 installed on your machine. Earlier or
later versions are not supported.
Node.js also needs to be installed along with yarn (can be installed with
the command `npm install -g yarn` if needed)
1. Open a command-line window on your machine. The PowerShell is recommended for Windows.
2. Create a directory to install InvokeAI into. You'll need at least 15 GB of free space:
@ -197,11 +199,18 @@ not supported.
7. Launch the web server (do it every time you run InvokeAI):
```terminal
invokeai --web
invokeai-web
```
8. Point your browser to http://localhost:9090 to bring up the web interface.
9. Type `banana sushi` in the box on the top left and click `Invoke`.
8. Build Node.js assets
```terminal
cd invokeai/frontend/web/
yarn vite build
```
9. Point your browser to http://localhost:9090 to bring up the web interface.
10. Type `banana sushi` in the box on the top left and click `Invoke`.
Be sure to activate the virtual environment each time before re-launching InvokeAI,
using `source .venv/bin/activate` or `.venv\Scripts\activate`.

View File

@ -81,3 +81,193 @@ pytest --cov; open ./coverage/html/index.html
<!--#TODO: get input from blessedcoolant here, for the moment inserted the frontend README via snippets extension.-->
--8<-- "invokeai/frontend/web/README.md"
## Developing InvokeAI in VSCode
VSCode offers some nice tools:
- python debugger
- automatic `venv` activation
- remote dev (e.g. run InvokeAI on a beefy linux desktop while you type in
comfort on your macbook)
### Setup
You'll need the
[Python](https://marketplace.visualstudio.com/items?itemName=ms-python.python)
and
[Pylance](https://marketplace.visualstudio.com/items?itemName=ms-python.vscode-pylance)
extensions installed first.
It's also really handy to install the `Jupyter` extensions:
- [Jupyter](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.jupyter)
- [Jupyter Cell Tags](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.vscode-jupyter-cell-tags)
- [Jupyter Notebook Renderers](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.jupyter-renderers)
- [Jupyter Slide Show](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.vscode-jupyter-slideshow)
#### InvokeAI workspace
Creating a VSCode workspace for working on InvokeAI is highly recommended. It
can hold InvokeAI-specific settings and configs.
To make a workspace:
- Open the InvokeAI repo dir in VSCode
- `File` > `Save Workspace As` > save it _outside_ the repo
#### Default python interpreter (i.e. automatic virtual environment activation)
- Use command palette to run command
`Preferences: Open Workspace Settings (JSON)`
- Add `python.defaultInterpreterPath` to `settings`, pointing to your `venv`'s
python
Should look something like this:
```jsonc
{
// I like to have all InvokeAI-related folders in my workspace
"folders": [
{
// repo root
"path": "InvokeAI"
},
{
// InvokeAI root dir, where `invokeai.yaml` lives
"path": "/path/to/invokeai_root"
}
],
"settings": {
// Where your InvokeAI `venv`'s python executable lives
"python.defaultInterpreterPath": "/path/to/invokeai_root/.venv/bin/python"
}
}
```
Now when you open the VSCode integrated terminal, or do anything that needs to
run python, it will automatically be in your InvokeAI virtual environment.
Bonus: When you create a Jupyter notebook, when you run it, you'll be prompted
for the python interpreter to run in. This will default to your `venv` python,
and so you'll have access to the same python environment as the InvokeAI app.
This is _super_ handy.
#### Debugging configs with `launch.json`
Debugging configs are managed in a `launch.json` file. Like most VSCode configs,
these can be scoped to a workspace or folder.
Follow the [official guide](https://code.visualstudio.com/docs/python/debugging)
to set up your `launch.json` and try it out.
Now we can create the InvokeAI debugging configs:
```jsonc
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
// Run the InvokeAI backend & serve the pre-built UI
"name": "InvokeAI Web",
"type": "python",
"request": "launch",
"program": "scripts/invokeai-web.py",
"args": [
// Your InvokeAI root dir (where `invokeai.yaml` lives)
"--root",
"/path/to/invokeai_root",
// Access the app from anywhere on your local network
"--host",
"0.0.0.0"
],
"justMyCode": true
},
{
// Run the nodes-based CLI
"name": "InvokeAI CLI",
"type": "python",
"request": "launch",
"program": "scripts/invokeai-cli.py",
"justMyCode": true
},
{
// Run tests
"name": "InvokeAI Test",
"type": "python",
"request": "launch",
"module": "pytest",
"args": ["--capture=no"],
"justMyCode": true
},
{
// Run a single test
"name": "InvokeAI Single Test",
"type": "python",
"request": "launch",
"module": "pytest",
"args": [
// Change this to point to the specific test you are working on
"tests/nodes/test_invoker.py"
],
"justMyCode": true
},
{
// This is the default, useful to just run a single file
"name": "Python: File",
"type": "python",
"request": "launch",
"program": "${file}",
"justMyCode": true
}
]
}
```
You'll see these configs in the debugging configs drop down. Running them will
start InvokeAI with attached debugger, in the correct environment, and work just
like the normal app.
Enjoy debugging InvokeAI with ease (not that we have any bugs of course).
#### Remote dev
This is very easy to set up and provides the same very smooth experience as
local development. Environments and debugging, as set up above, just work,
though you'd need to recreate the workspace and debugging configs on the remote.
Consult the
[official guide](https://code.visualstudio.com/docs/remote/remote-overview) to
get it set up.
Suggest using VSCode's included settings sync so that your remote dev host has
all the same app settings and extensions automagically.
##### One remote dev gotcha
I've found the automatic port forwarding to be very flakey. You can disable it
in `Preferences: Open Remote Settings (ssh: hostname)`. Search for
`remote.autoForwardPorts` and untick the box.
To forward ports very reliably, use SSH on the remote dev client (e.g. your
macbook). Here's how to forward both backend API port (`9090`) and the frontend
live dev server port (`5173`):
```bash
ssh \
-L 9090:localhost:9090 \
-L 5173:localhost:5173 \
user@remote-dev-host
```
The forwarding stops when you close the terminal window, so suggest to do this
_outside_ the VSCode integrated terminal in case you need to restart VSCode for
an extension update or something
Now, on your remote dev client, you can open `localhost:9090` and access the UI,
now served from the remote dev host, just the same as if it was running on the
client.

View File

@ -1,4 +1,8 @@
# Nodes Editor (Experimental Beta)
# Nodes Editor (Experimental)
🚨
*The node editor is experimental. We've made it accessible because we use it to develop the application, but we have not addressed the many known rough edges. It's very easy to shoot yourself in the foot, and we cannot offer support for it until it sees full release (ETA v3.1). Everything is subject to change without warning.*
🚨
The nodes editor is a blank canvas allowing for the use of individual functions and image transformations to control the image generation workflow. The node processing flow is usually done from left (inputs) to right (outputs), though linearity can become abstracted the more complex the node graph becomes. Nodes inputs and outputs are connected by dragging connectors from node to node.

View File

@ -11,6 +11,7 @@ from invokeai.app.services.board_images import (
)
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.resource_name import SimpleNameService
@ -56,7 +57,7 @@ class ApiDependencies:
invoker: Invoker = None
@staticmethod
def initialize(config, event_handler_id: int, logger: Logger = logger):
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
logger.debug(f"InvokeAI version {__version__}")
logger.debug(f"Internet connectivity is {config.internet_available}")

View File

@ -13,8 +13,11 @@ from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management.models import (
OPENAPI_MODEL_CONFIGS,
SchedulerPredictionType,
ModelNotFoundException,
InvalidModelException,
)
from invokeai.backend.model_management import MergeInterpolationMethod
from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -34,11 +37,16 @@ class ModelsList(BaseModel):
responses={200: {"model": ModelsList }},
)
async def list_models(
base_model: Optional[BaseModelType] = Query(default=None, description="Base model"),
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
) -> ModelsList:
"""Gets a list of models"""
models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)
if base_models and len(base_models)>0:
models_raw = list()
for base_model in base_models:
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
else:
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
models = parse_obj_as(ModelsList, { "models": models_raw })
return models
@ -46,8 +54,9 @@ async def list_models(
"/{base_model}/{model_type}/{model_name}",
operation_id="update_model",
responses={200: {"description" : "The model was updated successfully"},
400: {"description" : "Bad request"},
404: {"description" : "The model could not be found"},
400: {"description" : "Bad request"}
409: {"description" : "There is already a model corresponding to the new name"},
},
status_code = 200,
response_model = UpdateModelResponse,
@ -58,23 +67,58 @@ async def update_model(
model_name: str = Path(description="model name"),
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> UpdateModelResponse:
""" Add Model """
""" Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. """
logger = ApiDependencies.invoker.services.logger
try:
previous_info = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
# rename operation requested
if info.model_name != model_name or info.base_model != base_model:
ApiDependencies.invoker.services.model_manager.rename_model(
base_model = base_model,
model_type = model_type,
model_name = model_name,
new_name = info.model_name,
new_base = info.base_model,
)
logger.info(f'Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}')
# update information to support an update of attributes
model_name = info.model_name
base_model = info.base_model
new_info = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
if new_info.get('path') != previous_info.get('path'): # model manager moved model path during rename - don't overwrite it
info.path = new_info.get('path')
ApiDependencies.invoker.services.model_manager.update_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
model_attributes=info.dict()
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
model_response = parse_obj_as(UpdateModelResponse, model_raw)
except KeyError as e:
except ModelNotFoundException as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
except Exception as e:
logger.error(str(e))
raise HTTPException(status_code=400, detail=str(e))
return model_response
@ -85,6 +129,7 @@ async def update_model(
responses= {
201: {"description" : "The model imported successfully"},
404: {"description" : "The model could not be found"},
415: {"description" : "Unrecognized file/folder format"},
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description" : "There is already a model corresponding to this path or repo_id"},
},
@ -111,7 +156,7 @@ async def import_model(
if not info:
logger.error("Import failed")
raise HTTPException(status_code=424)
raise HTTPException(status_code=415)
logger.info(f'Successfully imported {location}, got {info}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
@ -121,9 +166,12 @@ async def import_model(
)
return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e:
except ModelNotFoundException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@ -161,57 +209,13 @@ async def add_model(
model_type=info.model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e:
except ModelNotFoundException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.post(
"/rename/{base_model}/{model_type}/{model_name}",
operation_id="rename_model",
responses= {
201: {"description" : "The model was renamed successfully"},
404: {"description" : "The model could not be found"},
409: {"description" : "There is already a model corresponding to the new name"},
},
status_code=201,
response_model=ImportModelResponse
)
async def rename_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="current model name"),
new_name: Optional[str] = Query(description="new model name", default=None),
new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
) -> ImportModelResponse:
""" Rename a model"""
logger = ApiDependencies.invoker.services.logger
try:
result = ApiDependencies.invoker.services.model_manager.rename_model(
base_model = base_model,
model_type = model_type,
model_name = model_name,
new_name = new_name,
new_base = new_base,
)
logger.debug(result)
logger.info(f'Successfully renamed {model_name}=>{new_name}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=new_name or model_name,
base_model=new_base or base_model,
model_type=model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.delete(
"/{base_model}/{model_type}/{model_name}",
@ -238,9 +242,9 @@ async def delete_model(
)
logger.info(f"Deleted model: {model_name}")
return Response(status_code=204)
except KeyError:
logger.error(f"Model not found: {model_name}")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
except ModelNotFoundException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
@models_router.put(
"/convert/{base_model}/{model_type}/{model_name}",
@ -273,8 +277,8 @@ async def convert_model(
base_model = base_model,
model_type = model_type)
response = parse_obj_as(ConvertModelResponse, model_raw)
except KeyError:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
except ModelNotFoundException as e:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response
@ -364,8 +368,55 @@ async def merge_models(
model_type = ModelType.Main,
)
response = parse_obj_as(ConvertModelResponse, model_raw)
except KeyError:
except ModelNotFoundException:
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response
# The rename operation is now supported by update_model and no longer needs to be
# a standalone route.
# @models_router.post(
# "/rename/{base_model}/{model_type}/{model_name}",
# operation_id="rename_model",
# responses= {
# 201: {"description" : "The model was renamed successfully"},
# 404: {"description" : "The model could not be found"},
# 409: {"description" : "There is already a model corresponding to the new name"},
# },
# status_code=201,
# response_model=ImportModelResponse
# )
# async def rename_model(
# base_model: BaseModelType = Path(description="Base model"),
# model_type: ModelType = Path(description="The type of model"),
# model_name: str = Path(description="current model name"),
# new_name: Optional[str] = Query(description="new model name", default=None),
# new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
# ) -> ImportModelResponse:
# """ Rename a model"""
# logger = ApiDependencies.invoker.services.logger
# try:
# result = ApiDependencies.invoker.services.model_manager.rename_model(
# base_model = base_model,
# model_type = model_type,
# model_name = model_name,
# new_name = new_name,
# new_base = new_base,
# )
# logger.debug(result)
# logger.info(f'Successfully renamed {model_name}=>{new_name}')
# model_raw = ApiDependencies.invoker.services.model_manager.list_model(
# model_name=new_name or model_name,
# base_model=new_base or base_model,
# model_type=model_type
# )
# return parse_obj_as(ImportModelResponse, model_raw)
# except ModelNotFoundException as e:
# logger.error(str(e))
# raise HTTPException(status_code=404, detail=str(e))
# except ValueError as e:
# logger.error(str(e))
# raise HTTPException(status_code=409, detail=str(e))

View File

@ -39,6 +39,7 @@ from .invocations.baseinvocation import BaseInvocation
import torch
import invokeai.backend.util.hotfixes
if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes

View File

@ -57,6 +57,7 @@ from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage
import torch
import invokeai.backend.util.hotfixes
if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes

View File

@ -1,8 +1,8 @@
from typing import Literal, Optional, Union, List
from typing import Literal, Optional, Union, List, Annotated
from pydantic import BaseModel, Field
import re
import torch
from compel import Compel
from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import (Blend, Conjunction,
CrossAttentionControlSubstitute,
FlattenedPrompt, Fragment)
@ -14,6 +14,7 @@ from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .model import ClipField
from dataclasses import dataclass
class ConditioningField(BaseModel):
@ -23,6 +24,34 @@ class ConditioningField(BaseModel):
class Config:
schema_extra = {"required": ["conditioning_name"]}
@dataclass
class BasicConditioningInfo:
#type: Literal["basic_conditioning"] = "basic_conditioning"
embeds: torch.Tensor
extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
# weight: float
# mode: ConditioningAlgo
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
#type: Literal["sdxl_conditioning"] = "sdxl_conditioning"
pooled_embeds: torch.Tensor
add_time_ids: torch.Tensor
ConditioningInfoType = Annotated[
Union[BasicConditioningInfo, SDXLConditioningInfo],
Field(discriminator="type")
]
@dataclass
class ConditioningFieldData:
conditionings: List[Union[BasicConditioningInfo, SDXLConditioningInfo]]
#unconditioned: Optional[torch.Tensor]
#class ConditioningAlgo(str, Enum):
# Compose = "compose"
# ComposeEx = "compose_ex"
# PerpNeg = "perp_neg"
class CompelOutput(BaseInvocationOutput):
"""Compel parser output"""
@ -57,10 +86,10 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
tokenizer_info = context.services.model_manager.get_model(
**self.clip.tokenizer.dict(),
**self.clip.tokenizer.dict(), context=context,
)
text_encoder_info = context.services.model_manager.get_model(
**self.clip.text_encoder.dict(),
**self.clip.text_encoder.dict(), context=context,
)
def _lora_loader():
@ -82,6 +111,7 @@ class CompelInvocation(BaseInvocation):
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
context=context,
).context.model
)
except ModelNotFoundException:
@ -118,10 +148,17 @@ class CompelInvocation(BaseInvocation):
cross_attention_control_args=options.get(
"cross_attention_control", None),)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
conditioning_data = ConditioningFieldData(
conditionings=[
BasicConditioningInfo(
embeds=c,
extra_conditioning=ec,
)
]
)
# TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.save(conditioning_name, (c, ec))
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
context.services.latents.save(conditioning_name, conditioning_data)
return CompelOutput(
conditioning=ConditioningField(
@ -129,6 +166,389 @@ class CompelInvocation(BaseInvocation):
),
)
class SDXLPromptInvocationBase:
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(),
)
text_encoder_info = context.services.model_manager.get_model(
**clip_field.text_encoder.dict(),
)
def _lora_loader():
for lora in clip_field.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
del lora_info
return
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
name = trigger[1:-1]
try:
ti_list.append(
context.services.model_manager.get_model(
model_name=name,
base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion,
).context.model
)
except ModelNotFoundException:
# print(e)
#import traceback
#print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found")
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\
text_encoder_info as text_encoder:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(
text_input_ids.to(text_encoder.device),
output_hidden_states=True,
)
if get_pooled:
c_pooled = prompt_embeds[0]
else:
c_pooled = None
c = prompt_embeds.hidden_states[-2]
del tokenizer
del text_encoder
del tokenizer_info
del text_encoder_info
return c, c_pooled, None
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(),
)
text_encoder_info = context.services.model_manager.get_model(
**clip_field.text_encoder.dict(),
)
def _lora_loader():
for lora in clip_field.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
del lora_info
return
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
name = trigger[1:-1]
try:
ti_list.append(
context.services.model_manager.get_model(
model_name=name,
base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion,
).context.model
)
except ModelNotFoundException:
# print(e)
#import traceback
#print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found")
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\
text_encoder_info as text_encoder:
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=True, # TODO:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=True,
)
conjunction = Compel.parse_prompt_string(prompt)
if context.services.configuration.log_tokenization:
# TODO: better logging for and syntax
for prompt_obj in conjunction.prompts:
log_tokenization_for_prompt_object(prompt_obj, tokenizer)
# TODO: ask for optimizations? to not run text_encoder twice
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
if get_pooled:
c_pooled = compel.conditioning_provider.get_pooled_embeddings([prompt])
else:
c_pooled = None
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
cross_attention_control_args=options.get("cross_attention_control", None),
)
del tokenizer
del text_encoder
del tokenizer_info
del text_encoder_info
return c, c_pooled, ec
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
type: Literal["sdxl_compel_prompt"] = "sdxl_compel_prompt"
prompt: str = Field(default="", description="Prompt")
style: str = Field(default="", description="Style prompt")
original_width: int = Field(1024, description="")
original_height: int = Field(1024, description="")
crop_top: int = Field(0, description="")
crop_left: int = Field(0, description="")
target_width: int = Field(1024, description="")
target_height: int = Field(1024, description="")
clip1: ClipField = Field(None, description="Clip to use")
clip2: ClipField = Field(None, description="Clip to use")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Prompt (Compel)",
"tags": ["prompt", "compel"],
"type_hints": {
"model": "model"
}
},
}
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip1, self.prompt, False)
if self.style.strip() == "":
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True)
else:
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
target_size = (self.target_height, self.target_width)
add_time_ids = torch.tensor([
original_size + crop_coords + target_size
])
conditioning_data = ConditioningFieldData(
conditionings=[
SDXLConditioningInfo(
embeds=torch.cat([c1, c2], dim=-1),
pooled_embeds=c2_pooled,
add_time_ids=add_time_ids,
extra_conditioning=ec1,
)
]
)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
context.services.latents.save(conditioning_name, conditioning_data)
return CompelOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt"
style: str = Field(default="", description="Style prompt") # TODO: ?
original_width: int = Field(1024, description="")
original_height: int = Field(1024, description="")
crop_top: int = Field(0, description="")
crop_left: int = Field(0, description="")
aesthetic_score: float = Field(6.0, description="")
clip2: ClipField = Field(None, description="Clip to use")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Refiner Prompt (Compel)",
"tags": ["prompt", "compel"],
"type_hints": {
"model": "model"
}
},
}
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
add_time_ids = torch.tensor([
original_size + crop_coords + (self.aesthetic_score,)
])
conditioning_data = ConditioningFieldData(
conditionings=[
SDXLConditioningInfo(
embeds=c2,
pooled_embeds=c2_pooled,
add_time_ids=add_time_ids,
extra_conditioning=ec2, # or None
)
]
)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
context.services.latents.save(conditioning_name, conditioning_data)
return CompelOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Pass unmodified prompt to conditioning without compel processing."""
type: Literal["sdxl_raw_prompt"] = "sdxl_raw_prompt"
prompt: str = Field(default="", description="Prompt")
style: str = Field(default="", description="Style prompt")
original_width: int = Field(1024, description="")
original_height: int = Field(1024, description="")
crop_top: int = Field(0, description="")
crop_left: int = Field(0, description="")
target_width: int = Field(1024, description="")
target_height: int = Field(1024, description="")
clip1: ClipField = Field(None, description="Clip to use")
clip2: ClipField = Field(None, description="Clip to use")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Prompt (Raw)",
"tags": ["prompt", "compel"],
"type_hints": {
"model": "model"
}
},
}
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip1, self.prompt, False)
if self.style.strip() == "":
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True)
else:
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
target_size = (self.target_height, self.target_width)
add_time_ids = torch.tensor([
original_size + crop_coords + target_size
])
conditioning_data = ConditioningFieldData(
conditionings=[
SDXLConditioningInfo(
embeds=torch.cat([c1, c2], dim=-1),
pooled_embeds=c2_pooled,
add_time_ids=add_time_ids,
extra_conditioning=ec1,
)
]
)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
context.services.latents.save(conditioning_name, conditioning_data)
return CompelOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
type: Literal["sdxl_refiner_raw_prompt"] = "sdxl_refiner_raw_prompt"
style: str = Field(default="", description="Style prompt") # TODO: ?
original_width: int = Field(1024, description="")
original_height: int = Field(1024, description="")
crop_top: int = Field(0, description="")
crop_left: int = Field(0, description="")
aesthetic_score: float = Field(6.0, description="")
clip2: ClipField = Field(None, description="Clip to use")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Refiner Prompt (Raw)",
"tags": ["prompt", "compel"],
"type_hints": {
"model": "model"
}
},
}
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
add_time_ids = torch.tensor([
original_size + crop_coords + (self.aesthetic_score,)
])
conditioning_data = ConditioningFieldData(
conditionings=[
SDXLConditioningInfo(
embeds=c2,
pooled_embeds=c2_pooled,
add_time_ids=add_time_ids,
extra_conditioning=ec2, # or None
)
]
)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
context.services.latents.save(conditioning_name, conditioning_data)
return CompelOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
class ClipSkipInvocationOutput(BaseInvocationOutput):
"""Clip skip node output"""
type: Literal["clip_skip_output"] = "clip_skip_output"

View File

@ -157,13 +157,13 @@ class InpaintInvocation(BaseInvocation):
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
**lora.dict(exclude={"weight"}), context=context,)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context,)
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict(), context=context,)
with vae_info as vae,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\

View File

@ -31,6 +31,13 @@ from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations"""
@ -76,7 +83,7 @@ def get_scheduler(
scheduler_name, SCHEDULER_MAP['ddim']
)
orig_scheduler_info = context.services.model_manager.get_model(
**scheduler_info.dict()
**scheduler_info.dict(), context=context,
)
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
@ -161,12 +168,12 @@ class TextToLatentsInvocation(BaseInvocation):
context: InvocationContext,
scheduler,
) -> ConditioningData:
c, extra_conditioning_info = context.services.latents.get(
self.positive_conditioning.conditioning_name
)
uc, _ = context.services.latents.get(
self.negative_conditioning.conditioning_name
)
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
c = positive_cond_data.conditionings[0].embeds
extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
uc = negative_cond_data.conditionings[0].embeds
conditioning_data = ConditioningData(
unconditioned_embeddings=uc,
@ -262,6 +269,7 @@ class TextToLatentsInvocation(BaseInvocation):
model_name=control_info.control_model.model_name,
model_type=ModelType.ControlNet,
base_model=control_info.control_model.base_model,
context=context,
)
)
@ -313,14 +321,14 @@ class TextToLatentsInvocation(BaseInvocation):
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"})
**lora.dict(exclude={"weight"}), context=context,
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict()
**self.unet.unet.dict(), context=context,
)
with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
@ -403,14 +411,14 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"})
**lora.dict(exclude={"weight"}), context=context,
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict()
**self.unet.unet.dict(), context=context,
)
with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
@ -475,8 +483,8 @@ class LatentsToImageInvocation(BaseInvocation):
tiled: bool = Field(
default=False,
description="Decode latents by overlaping tiles(less memory consumption)")
fp32: bool = Field(False, description="Decode in full precision")
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
# Schema customisation
class Config(InvocationConfig):
@ -491,10 +499,35 @@ class LatentsToImageInvocation(BaseInvocation):
latents = context.services.latents.get(self.latents.latents_name)
vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(),
**self.vae.vae.dict(), context=context,
)
with vae_info as vae:
if self.fp32:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = isinstance(
vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
vae.post_quant_conv.to(latents.dtype)
vae.decoder.conv_in.to(latents.dtype)
vae.decoder.mid_block.to(latents.dtype)
else:
latents = latents.float()
else:
vae.to(dtype=torch.float16)
latents = latents.half()
if self.tiled or context.services.configuration.tiled_decode:
vae.enable_tiling()
else:
@ -618,6 +651,8 @@ class ImageToLatentsInvocation(BaseInvocation):
tiled: bool = Field(
default=False,
description="Encode latents by overlaping tiles(less memory consumption)")
fp32: bool = Field(False, description="Decode in full precision")
# Schema customisation
class Config(InvocationConfig):
@ -636,7 +671,7 @@ class ImageToLatentsInvocation(BaseInvocation):
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(),
**self.vae.vae.dict(), context=context,
)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
@ -644,6 +679,32 @@ class ImageToLatentsInvocation(BaseInvocation):
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
with vae_info as vae:
orig_dtype = vae.dtype
if self.fp32:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = isinstance(
vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
vae.post_quant_conv.to(orig_dtype)
vae.decoder.conv_in.to(orig_dtype)
vae.decoder.mid_block.to(orig_dtype)
#else:
# latents = latents.float()
else:
vae.to(dtype=torch.float16)
#latents = latents.half()
if self.tiled:
vae.enable_tiling()
else:
@ -658,6 +719,7 @@ class ImageToLatentsInvocation(BaseInvocation):
) # FIXME: uses torch.randn. make reproducible!
latents = 0.18215 * latents
latents = latents.to(dtype=orig_dtype)
name = f"{context.graph_execution_state_id}__{self.id}"
# context.services.latents.set(name, latents)

View File

@ -33,7 +33,6 @@ class ClipField(BaseModel):
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
class VaeField(BaseModel):
# TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel")
@ -50,7 +49,6 @@ class ModelLoaderOutput(BaseInvocationOutput):
vae: VaeField = Field(default=None, description="Vae submodel")
# fmt: on
class MainModelField(BaseModel):
"""Main model field"""
@ -64,7 +62,6 @@ class LoRAModelField(BaseModel):
model_name: str = Field(description="Name of the LoRA model")
base_model: BaseModelType = Field(description="Base model")
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
@ -157,6 +154,22 @@ class MainModelLoaderInvocation(BaseInvocation):
loras=[],
skipped_layers=0,
),
clip2=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
@ -167,7 +180,7 @@ class MainModelLoaderInvocation(BaseInvocation):
),
)
class LoraLoaderOutput(BaseInvocationOutput):
"""Model loader output"""

View File

@ -1,6 +1,8 @@
from typing import Literal
from os.path import exists
from typing import Literal, Optional
from pydantic.fields import Field
import numpy as np
from pydantic import Field, validator
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
@ -55,3 +57,41 @@ class DynamicPromptInvocation(BaseInvocation):
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
class PromptsFromFileInvocation(BaseInvocation):
'''Loads prompts from a text file'''
# fmt: off
type: Literal['prompt_from_file'] = 'prompt_from_file'
# Inputs
file_path: str = Field(description="Path to prompt text file")
pre_prompt: Optional[str] = Field(description="String to prepend to each prompt")
post_prompt: Optional[str] = Field(description="String to append to each prompt")
start_line: int = Field(default=1, ge=1, description="Line in the file to start start from")
max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)")
#fmt: on
@validator("file_path")
def file_path_exists(cls, v):
if not exists(v):
raise ValueError(FileNotFoundError)
return v
def promptsFromFile(self, file_path: str, pre_prompt: str, post_prompt: str, start_line: int, max_prompts: int):
prompts = []
start_line -= 1
end_line = start_line + max_prompts
if max_prompts <= 0:
end_line = np.iinfo(np.int32).max
with open(file_path) as f:
for i, line in enumerate(f):
if i >= start_line and i < end_line:
prompts.append((pre_prompt or '') + line.strip() + (post_prompt or ''))
if i >= end_line:
break
return prompts
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
prompts = self.promptsFromFile(self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts)
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))

View File

@ -0,0 +1,658 @@
import torch
import inspect
from tqdm import tqdm
from typing import List, Literal, Optional, Union
from pydantic import Field, validator
from ...backend.model_management import ModelType, SubModelType
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
from .compel import ConditioningField
from .latent import LatentsField, SAMPLER_NAME_VALUES, LatentsOutput, get_scheduler, build_latents_output
class SDXLModelLoaderOutput(BaseInvocationOutput):
"""SDXL base model loader output"""
# fmt: off
type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"
unet: UNetField = Field(default=None, description="UNet submodel")
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
vae: VaeField = Field(default=None, description="Vae submodel")
# fmt: on
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
"""SDXL refiner model loader output"""
# fmt: off
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
unet: UNetField = Field(default=None, description="UNet submodel")
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
vae: VaeField = Field(default=None, description="Vae submodel")
# fmt: on
#fmt: on
class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels."""
type: Literal["sdxl_model_loader"] = "sdxl_model_loader"
model: MainModelField = Field(description="The model to load")
# TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Model Loader",
"tags": ["model", "loader", "sdxl"],
"type_hints": {"model": "model"},
},
}
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Main
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
return SDXLModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
),
clip2=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Vae,
),
),
)
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels."""
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
model: MainModelField = Field(description="The model to load")
# TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Refiner Model Loader",
"tags": ["model", "loader", "sdxl_refiner"],
"type_hints": {"model": "model"},
},
}
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Main
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
return SDXLRefinerModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip2=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Vae,
),
),
)
# Text to image
class SDXLTextToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings."""
type: Literal["t2l_sdxl"] = "t2l_sdxl"
# Inputs
# fmt: off
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
unet: UNetField = Field(default=None, description="UNet submodel")
denoising_end: float = Field(default=1.0, gt=0, le=1, description="")
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# fmt: on
@validator("cfg_scale")
def ge_one(cls, v):
"""validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError('cfg_scale must be greater than 1')
else:
if v < 1:
raise ValueError('cfg_scale must be greater than 1')
return v
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents"],
"type_hints": {
"model": "model",
# "cfg_scale": "float",
"cfg_scale": "number"
}
},
}
# based on
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.noise.latents_name)
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
prompt_embeds = positive_cond_data.conditionings[0].embeds
pooled_prompt_embeds = positive_cond_data.conditionings[0].pooled_embeds
add_time_ids = positive_cond_data.conditionings[0].add_time_ids
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
negative_prompt_embeds = negative_cond_data.conditionings[0].embeds
negative_pooled_prompt_embeds = negative_cond_data.conditionings[0].pooled_embeds
add_neg_time_ids = negative_cond_data.conditionings[0].add_time_ids
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
num_inference_steps = self.steps
scheduler.set_timesteps(num_inference_steps)
timesteps = scheduler.timesteps
latents = latents * scheduler.init_noise_sigma
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict()
)
do_classifier_free_guidance = True
cross_attention_kwargs = None
with unet_info as unet:
extra_step_kwargs = dict()
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
extra_step_kwargs.update(
eta=0.0,
)
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
extra_step_kwargs.update(
generator=torch.Generator(device=unet.device).manual_seed(0),
)
num_warmup_steps = len(timesteps) - self.steps * scheduler.order
# apply denoising_end
skipped_final_steps = int(round((1 - self.denoising_end) * self.steps))
num_inference_steps = num_inference_steps - skipped_final_steps
timesteps = timesteps[: num_warmup_steps + scheduler.order * num_inference_steps]
if not context.services.configuration.sequential_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
add_text_embeds = add_text_embeds.to(device=unet.device, dtype=unet.dtype)
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
latents = latents.to(device=unet.device, dtype=unet.dtype)
with tqdm(total=self.steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
noise_pred = unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
#del noise_pred_uncond
#del noise_pred_text
#if do_classifier_free_guidance and guidance_rescale > 0.0:
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
#if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
else:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
negative_prompt_embeds = negative_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
add_neg_time_ids = add_neg_time_ids.to(device=unet.device, dtype=unet.dtype)
pooled_prompt_embeds = pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
latents = latents.to(device=unet.device, dtype=unet.dtype)
with tqdm(total=self.steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
#latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = scheduler.scale_model_input(latents, t)
#import gc
#gc.collect()
#torch.cuda.empty_cache()
# predict the noise residual
added_cond_kwargs = {"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_neg_time_ids}
noise_pred_uncond = unet(
latent_model_input,
t,
encoder_hidden_states=negative_prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
noise_pred_text = unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
#del noise_pred_text
#del noise_pred_uncond
#import gc
#gc.collect()
#torch.cuda.empty_cache()
#if do_classifier_free_guidance and guidance_rescale > 0.0:
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
#del noise_pred
#import gc
#gc.collect()
#torch.cuda.empty_cache()
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
#if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
#################
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents)
class SDXLLatentsToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings."""
type: Literal["l2l_sdxl"] = "l2l_sdxl"
# Inputs
# fmt: off
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
unet: UNetField = Field(default=None, description="UNet submodel")
latents: Optional[LatentsField] = Field(description="Initial latents")
denoising_start: float = Field(default=0.0, ge=0, lt=1, description="")
denoising_end: float = Field(default=1.0, gt=0, le=1, description="")
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# fmt: on
@validator("cfg_scale")
def ge_one(cls, v):
"""validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError('cfg_scale must be greater than 1')
else:
if v < 1:
raise ValueError('cfg_scale must be greater than 1')
return v
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents"],
"type_hints": {
"model": "model",
# "cfg_scale": "float",
"cfg_scale": "number"
}
},
}
# based on
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
prompt_embeds = positive_cond_data.conditionings[0].embeds
pooled_prompt_embeds = positive_cond_data.conditionings[0].pooled_embeds
add_time_ids = positive_cond_data.conditionings[0].add_time_ids
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
negative_prompt_embeds = negative_cond_data.conditionings[0].embeds
negative_pooled_prompt_embeds = negative_cond_data.conditionings[0].pooled_embeds
add_neg_time_ids = negative_cond_data.conditionings[0].add_time_ids
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
# apply denoising_start
num_inference_steps = self.steps
scheduler.set_timesteps(num_inference_steps)
t_start = int(round(self.denoising_start * num_inference_steps))
timesteps = scheduler.timesteps[t_start * scheduler.order:]
num_inference_steps = num_inference_steps - t_start
# apply noise(if provided)
if self.noise is not None:
noise = context.services.latents.get(self.noise.latents_name)
latents = scheduler.add_noise(latents, noise, timesteps[:1])
del noise
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict()
)
do_classifier_free_guidance = True
cross_attention_kwargs = None
with unet_info as unet:
# apply scheduler extra args
extra_step_kwargs = dict()
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
extra_step_kwargs.update(
eta=0.0,
)
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
extra_step_kwargs.update(
generator=torch.Generator(device=unet.device).manual_seed(0),
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0)
# apply denoising_end
skipped_final_steps = int(round((1 - self.denoising_end) * self.steps))
num_inference_steps = num_inference_steps - skipped_final_steps
timesteps = timesteps[: num_warmup_steps + scheduler.order * num_inference_steps]
if not context.services.configuration.sequential_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
add_text_embeds = add_text_embeds.to(device=unet.device, dtype=unet.dtype)
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
latents = latents.to(device=unet.device, dtype=unet.dtype)
with tqdm(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
noise_pred = unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
#del noise_pred_uncond
#del noise_pred_text
#if do_classifier_free_guidance and guidance_rescale > 0.0:
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
#if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
else:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
negative_prompt_embeds = negative_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
add_neg_time_ids = add_neg_time_ids.to(device=unet.device, dtype=unet.dtype)
pooled_prompt_embeds = pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
latents = latents.to(device=unet.device, dtype=unet.dtype)
with tqdm(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
#latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = scheduler.scale_model_input(latents, t)
#import gc
#gc.collect()
#torch.cuda.empty_cache()
# predict the noise residual
added_cond_kwargs = {"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_time_ids}
noise_pred_uncond = unet(
latent_model_input,
t,
encoder_hidden_states=negative_prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
noise_pred_text = unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
#del noise_pred_text
#del noise_pred_uncond
#import gc
#gc.collect()
#torch.cuda.empty_cache()
#if do_classifier_free_guidance and guidance_rescale > 0.0:
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
#del noise_pred
#import gc
#gc.collect()
#torch.cuda.empty_cache()
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
#if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
#################
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents)

View File

@ -1,5 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
from pathlib import Path
from pathlib import Path, PosixPath
from typing import Literal, Union, cast
import cv2 as cv
@ -33,12 +33,12 @@ class RealESRGANInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) # type: ignore
models_dir = cast(Path, context.services.configuration.root_dir) / Path("models/") # type: ignore
image = context.services.images.get_pil_image(self.image.image_name)
models_path = context.services.configuration.models_path
rrdbnet_model = None
netscale = None
model_path = None
esrgan_model_path = None
if self.model_name in [
"RealESRGAN_x4plus.pth",
@ -54,13 +54,13 @@ class RealESRGANInvocation(BaseInvocation):
scale=4,
)
netscale = 4
elif self.model_name == "RealESRGAN_x4plus_anime_6B.pth":
elif self.model_name in ["RealESRGAN_x4plus_anime_6B.pth"]:
# x4 RRDBNet model, 6 blocks
rrdbnet_model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=6, # 6 blocks
num_block=6, # 6 blocks
num_grow_ch=32,
scale=4,
)
@ -83,11 +83,11 @@ class RealESRGANInvocation(BaseInvocation):
context.services.logger.error(msg)
raise ValueError(msg)
model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
upsampler = RealESRGANer(
scale=netscale,
model_path=str(models_dir / model_path),
model_path=str(models_path / esrgan_model_path),
model=rrdbnet_model,
half=False,
)

View File

@ -105,8 +105,6 @@ class EventServiceBase:
def emit_model_load_started (
self,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
@ -117,8 +115,6 @@ class EventServiceBase:
event_name="model_load_started",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
@ -129,8 +125,6 @@ class EventServiceBase:
def emit_model_load_completed(
self,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
@ -142,12 +136,12 @@ class EventServiceBase:
event_name="model_load_completed",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
model_info=model_info,
hash=model_info.hash,
location=model_info.location,
precision=str(model_info.precision),
),
)

View File

@ -12,7 +12,7 @@ if TYPE_CHECKING:
from invokeai.app.services.latent_storage import LatentsStorageBase
from invokeai.app.services.invocation_queue import InvocationQueueABC
from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.config import InvokeAISettings
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
from invokeai.app.services.invoker import InvocationProcessorABC
@ -23,7 +23,7 @@ class InvocationServices:
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
board_images: "BoardImagesServiceABC"
boards: "BoardServiceABC"
configuration: "InvokeAISettings"
configuration: "InvokeAIAppConfig"
events: "EventServiceBase"
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
graph_library: "ItemStorageABC"["LibraryGraph"]
@ -38,7 +38,7 @@ class InvocationServices:
self,
board_images: "BoardImagesServiceABC",
boards: "BoardServiceABC",
configuration: "InvokeAISettings",
configuration: "InvokeAIAppConfig",
events: "EventServiceBase",
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
graph_library: "ItemStorageABC"["LibraryGraph"],

View File

@ -18,6 +18,7 @@ from invokeai.backend.model_management import (
SchedulerPredictionType,
ModelMerger,
MergeInterpolationMethod,
ModelNotFoundException,
)
from invokeai.backend.model_management.model_search import FindModels
@ -145,7 +146,7 @@ class ModelManagerServiceBase(ABC):
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with a
KeyErrorException if the name does not already exist.
ModelNotFoundException if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
@ -338,7 +339,6 @@ class ModelManagerService(ModelManagerServiceBase):
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None,
) -> ModelInfo:
"""
@ -346,11 +346,9 @@ class ModelManagerService(ModelManagerServiceBase):
part (such as the vae) of a diffusers mode.
"""
# if we are called from within a node, then we get to emit
# load start and complete events
if node and context:
# we can emit model loading events if we are executing with access to the invocation context
if context:
self._emit_load_event(
node=node,
context=context,
model_name=model_name,
base_model=base_model,
@ -365,9 +363,8 @@ class ModelManagerService(ModelManagerServiceBase):
submodel,
)
if node and context:
if context:
self._emit_load_event(
node=node,
context=context,
model_name=model_name,
base_model=base_model,
@ -451,14 +448,14 @@ class ModelManagerService(ModelManagerServiceBase):
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with a
KeyError exception if the name does not already exist.
ModelNotFoundException exception if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
self.logger.debug(f'update model {model_name}')
if not self.model_exists(model_name, base_model, model_type):
raise KeyError(f"Unknown model {model_name}")
raise ModelNotFoundException(f"Unknown model {model_name}")
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
def del_model(
@ -509,23 +506,19 @@ class ModelManagerService(ModelManagerServiceBase):
def _emit_load_event(
self,
node,
context,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
submodel: Optional[SubModelType] = None,
model_info: Optional[ModelInfo] = None,
):
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException()
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[node.id]
if model_info:
context.services.events.emit_model_load_completed(
graph_execution_state_id=context.graph_execution_state_id,
node=node.dict(),
source_node_id=source_node_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
@ -535,8 +528,6 @@ class ModelManagerService(ModelManagerServiceBase):
else:
context.services.events.emit_model_load_started(
graph_execution_state_id=context.graph_execution_state_id,
node=node.dict(),
source_node_id=source_node_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,

View File

@ -212,7 +212,7 @@ class ModelInstall(object):
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
]
):
models_installed.update(self._install_path(path))
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
# recursive scan
elif path.is_dir():

View File

@ -3,6 +3,6 @@ Initialization file for invokeai.backend.model_management
"""
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType, ModelNotFoundException
from .model_merge import ModelMerger, MergeInterpolationMethod

View File

@ -552,7 +552,7 @@ class ModelManager(object):
model_config = self.models.get(model_key)
if not model_config:
self.logger.error(f'Unknown model {model_name}')
raise KeyError(f'Unknown model {model_name}')
raise ModelNotFoundException(f'Unknown model {model_name}')
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
if base_model is not None and cur_base_model != base_model:
@ -568,6 +568,9 @@ class ModelManager(object):
model_type=cur_model_type,
)
# expose paths as absolute to help web UI
if path := model_dict.get('path'):
model_dict['path'] = str(self.app_config.root_path / path)
models.append(model_dict)
return models
@ -596,7 +599,7 @@ class ModelManager(object):
model_cfg = self.models.pop(model_key, None)
if model_cfg is None:
raise KeyError(f"Unknown model {model_key}")
raise ModelNotFoundException(f"Unknown model {model_key}")
# note: it not garantie to release memory(model can has other references)
cache_ids = self.cache_keys.pop(model_key, [])
@ -635,6 +638,10 @@ class ModelManager(object):
The returned dict has the same format as the dict returned by
model_info().
"""
# relativize paths as they go in - this makes it easier to move the root directory around
if path := model_attributes.get('path'):
if Path(path).is_relative_to(self.app_config.root_path):
model_attributes['path'] = str(Path(path).relative_to(self.app_config.root_path))
model_class = MODEL_CLASSES[base_model][model_type]
model_config = model_class.create_config(**model_attributes)
@ -689,7 +696,7 @@ class ModelManager(object):
model_key = self.create_key(model_name, base_model, model_type)
model_cfg = self.models.get(model_key, None)
if not model_cfg:
raise KeyError(f"Unknown model: {model_key}")
raise ModelNotFoundException(f"Unknown model: {model_key}")
old_path = self.app_config.root_path / model_cfg.path
new_name = new_name or model_name
@ -700,7 +707,7 @@ class ModelManager(object):
# if this is a model file/directory that we manage ourselves, we need to move it
if old_path.is_relative_to(self.app_config.models_path):
new_path = self.app_config.root_path / 'models' / new_base.value / model_type.value / new_name
new_path = self.app_config.root_path / 'models' / BaseModelType(new_base).value / ModelType(model_type).value / new_name
move(old_path, new_path)
model_cfg.path = str(new_path.relative_to(self.app_config.root_path))
@ -964,7 +971,7 @@ class ModelManager(object):
that model.
May return the following exceptions:
- KeyError - one or more of the items to import is not a valid path, repo_id or URL
- ModelNotFoundException - one or more of the items to import is not a valid path, repo_id or URL
- ValueError - a corresponding model already exists
'''
# avoid circular import here

View File

@ -12,6 +12,7 @@ from picklescan.scanner import scan_file_path
from .models import (
BaseModelType, ModelType, ModelVariantType,
SchedulerPredictionType, SilenceWarnings,
InvalidModelException
)
from .models.base import read_checkpoint_meta
@ -38,6 +39,8 @@ class ModelProbe(object):
CLASS2TYPE = {
'StableDiffusionPipeline' : ModelType.Main,
'StableDiffusionXLPipeline' : ModelType.Main,
'StableDiffusionXLImg2ImgPipeline' : ModelType.Main,
'AutoencoderKL' : ModelType.Vae,
'ControlNetModel' : ModelType.ControlNet,
}
@ -59,7 +62,7 @@ class ModelProbe(object):
elif isinstance(model,(dict,ModelMixin,ConfigMixin)):
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
else:
raise ValueError("model parameter {model} is neither a Path, nor a model")
raise InvalidModelException("model parameter {model} is neither a Path, nor a model")
@classmethod
def probe(cls,
@ -99,9 +102,10 @@ class ModelProbe(object):
upcast_attention = (base_type==BaseModelType.StableDiffusion2 \
and prediction_type==SchedulerPredictionType.VPrediction),
format = format,
image_size = 768 if (base_type==BaseModelType.StableDiffusion2 \
and prediction_type==SchedulerPredictionType.VPrediction \
) else 512,
image_size = 1024 if (base_type in {BaseModelType.StableDiffusionXL,BaseModelType.StableDiffusionXLRefiner}) else \
768 if (base_type==BaseModelType.StableDiffusion2 \
and prediction_type==SchedulerPredictionType.VPrediction ) else \
512
)
except Exception:
raise
@ -138,7 +142,7 @@ class ModelProbe(object):
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
return ModelType.TextualInversion
raise ValueError(f"Unable to determine model type for {model_path}")
raise InvalidModelException(f"Unable to determine model type for {model_path}")
@classmethod
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
@ -168,7 +172,7 @@ class ModelProbe(object):
return type
# give up
raise ValueError(f"Unable to determine model type for {folder_path}")
raise InvalidModelException(f"Unable to determine model type for {folder_path}")
@classmethod
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
@ -237,7 +241,7 @@ class CheckpointProbeBase(ProbeBase):
elif in_channels == 4:
return ModelVariantType.Normal
else:
raise ValueError(f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}")
raise InvalidModelException(f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}")
class PipelineCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
@ -248,7 +252,10 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusion1
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
raise ValueError("Cannot determine base type")
# TODO: Verify that this is correct! Need an XL checkpoint file for this.
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
return BaseModelType.StableDiffusionXL
raise InvalidModelException("Cannot determine base type")
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
type = self.get_base_type()
@ -329,7 +336,7 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusion2
elif self.checkpoint_path and self.helper:
return self.helper(self.checkpoint_path)
raise ValueError("Unable to determine base type for {self.checkpoint_path}")
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
########################################################
# classes for probing folders
@ -360,8 +367,12 @@ class PipelineFolderProbe(FolderProbeBase):
return BaseModelType.StableDiffusion1
elif unet_conf['cross_attention_dim'] == 1024:
return BaseModelType.StableDiffusion2
elif unet_conf['cross_attention_dim'] == 1280:
return BaseModelType.StableDiffusionXLRefiner
elif unet_conf['cross_attention_dim'] == 2048:
return BaseModelType.StableDiffusionXL
else:
raise ValueError(f'Unknown base model for {self.folder_path}')
raise InvalidModelException(f'Unknown base model for {self.folder_path}')
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
if self.model:
@ -418,7 +429,7 @@ class ControlNetFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
config_file = self.folder_path / 'config.json'
if not config_file.exists():
raise ValueError(f"Cannot determine base type for {self.folder_path}")
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
with open(config_file,'r') as file:
config = json.load(file)
# no obvious way to distinguish between sd2-base and sd2-768
@ -435,7 +446,7 @@ class LoRAFolderProbe(FolderProbeBase):
model_file = base_file
break
if not model_file:
raise ValueError('Unknown LoRA format encountered')
raise InvalidModelException('Unknown LoRA format encountered')
return LoRACheckpointProbe(model_file,None).get_base_type()
############## register probe classes ######

View File

@ -4,6 +4,7 @@ from pydantic import BaseModel
from typing import Literal, get_origin
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException, InvalidModelException
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .sdxl import StableDiffusionXLModel
from .vae import VaeModel
from .lora import LoRAModel
from .controlnet import ControlNetModel # TODO:
@ -24,6 +25,22 @@ MODEL_CLASSES = {
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
BaseModelType.StableDiffusionXL: {
ModelType.Main: StableDiffusionXLModel,
ModelType.Vae: VaeModel,
# will not work until support written
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
BaseModelType.StableDiffusionXLRefiner: {
ModelType.Main: StableDiffusionXLModel,
ModelType.Vae: VaeModel,
# will not work until support written
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
#BaseModelType.Kandinsky2_1: {
# ModelType.Main: Kandinsky2_1Model,
# ModelType.MoVQ: MoVQModel,

View File

@ -24,6 +24,8 @@ class ModelNotFoundException(Exception):
class BaseModelType(str, Enum):
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
#Kandinsky2_1 = "kandinsky-2.1"
class ModelType(str, Enum):
@ -36,7 +38,9 @@ class ModelType(str, Enum):
class SubModelType(str, Enum):
UNet = "unet"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Vae = "vae"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"

View File

@ -0,0 +1,114 @@
import os
import json
from enum import Enum
from pydantic import Field
from typing import Literal, Optional
from .base import (
ModelConfigBase,
BaseModelType,
ModelType,
ModelVariantType,
DiffusersModel,
read_checkpoint_meta,
classproperty,
)
from omegaconf import OmegaConf
class StableDiffusionXLModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusionXLModel(DiffusersModel):
# TODO: check that configs overwriten properly
class DiffusersConfig(ModelConfigBase):
model_format: Literal[StableDiffusionXLModelFormat.Diffusers]
vae: Optional[str] = Field(None)
variant: ModelVariantType
class CheckpointConfig(ModelConfigBase):
model_format: Literal[StableDiffusionXLModelFormat.Checkpoint]
vae: Optional[str] = Field(None)
config: str
variant: ModelVariantType
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}
assert model_type == ModelType.Main
super().__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusionXL,
model_type=ModelType.Main,
)
@classmethod
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None)
if model_format == StableDiffusionXLModelFormat.Checkpoint:
if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
else:
checkpoint = read_checkpoint_meta(path)
checkpoint = checkpoint.get('state_dict', checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == StableDiffusionXLModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read())
in_channels = unet_config['in_channels']
else:
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
else:
raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")
if in_channels == 9:
variant = ModelVariantType.Inpaint
elif in_channels == 5:
variant = ModelVariantType.Depth
elif in_channels == 4:
variant = ModelVariantType.Normal
else:
raise Exception("Unkown stable diffusion 2.* model format")
if ckpt_config_path is None:
# TO DO: implement picking
pass
return cls.create_config(
path=path,
model_format=model_format,
config=ckpt_config_path,
variant=variant,
)
@classproperty
def save_to_config(cls) -> bool:
return True
@classmethod
def detect_format(cls, model_path: str):
if os.path.isdir(model_path):
return StableDiffusionXLModelFormat.Diffusers
else:
return StableDiffusionXLModelFormat.Checkpoint
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
if isinstance(config, cls.CheckpointConfig):
raise NotImplementedError('conversion of SDXL checkpoint models to diffusers format is not yet supported')
else:
return model_path

View File

@ -5,14 +5,11 @@ from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
from .base import (
ModelBase,
ModelConfigBase,
BaseModelType,
ModelType,
SubModelType,
ModelVariantType,
DiffusersModel,
SchedulerPredictionType,
SilenceWarnings,
read_checkpoint_meta,
classproperty,
@ -248,6 +245,12 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512)
ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
ModelVariantType.Depth: "v2-midas-inference.yaml",
},
# note that these .yaml files don't yet exist!
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: "xl-inference-v.yaml",
ModelVariantType.Inpaint: "xl-inpainting-inference.yaml",
ModelVariantType.Depth: "xl-midas-inference.yaml",
}
}
@ -263,6 +266,7 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
# TODO: rework
# Note that convert_ckpt_to_diffuses does not currently support conversion of SDXL models
def _convert_ckpt_and_cache(
version: BaseModelType,
model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig],

View File

@ -68,7 +68,7 @@ class TextualInversionModel(ModelBase):
return None # diffusers-ti
if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]]):
return None
raise InvalidModelException(f"Not a valid model: {path}")

View File

@ -16,6 +16,7 @@ from .base import (
calc_model_size_by_data,
classproperty,
InvalidModelException,
ModelNotFoundException,
)
from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available

View File

@ -221,7 +221,7 @@ class ControlNetData:
control_mode: str = Field(default="balanced")
@dataclass(frozen=True)
@dataclass
class ConditioningData:
unconditioned_embeddings: torch.Tensor
text_embeddings: torch.Tensor
@ -422,7 +422,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise: torch.Tensor,
callback: Callable[[PipelineIntermediateState], None] = None,
run_id=None,
**kwargs,
) -> InvokeAIStableDiffusionPipelineOutput:
r"""
Function invoked when calling the pipeline for generation.
@ -443,7 +442,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise=noise,
run_id=run_id,
callback=callback,
**kwargs,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
@ -469,7 +467,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
run_id=None,
callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if self.scheduler.config.get("cpu_only", False):
scheduler_device = torch.device('cpu')
@ -487,11 +484,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
timesteps,
conditioning_data,
noise=noise,
additional_guidance=additional_guidance,
run_id=run_id,
callback=callback,
additional_guidance=additional_guidance,
control_data=control_data,
**kwargs,
callback=callback,
)
return result.latents, result.attention_map_saver
@ -505,7 +502,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
run_id: str = None,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
**kwargs,
):
self._adjust_memory_efficient_attention(latents)
if run_id is None:
@ -546,7 +542,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count=len(timesteps),
additional_guidance=additional_guidance,
control_data=control_data,
**kwargs,
)
latents = step_output.prev_sample
@ -588,7 +583,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count: int,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
**kwargs,
):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0]
@ -632,9 +626,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
encoder_hidden_states = conditioning_data.text_embeddings
encoder_attention_mask = None
else:
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings])
encoder_hidden_states, encoder_attention_mask = self.invokeai_diffuser._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings,
)
if isinstance(control_datum.weight, list):
# if controlnet has multiple weights, use the weight for the current step
controlnet_weight = control_datum.weight[step_index]
@ -649,6 +646,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=control_datum.image_tensor,
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
encoder_attention_mask=encoder_attention_mask,
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
return_dict=False,
)

View File

@ -237,11 +237,7 @@ class InvokeAIDiffuserComponent:
)
return latents
# methods below are called from do_diffusion_step and should be considered private to this class.
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
# fast batched path
def _concat_conditionings_for_batch(self, unconditioning, conditioning):
def _pad_conditioning(cond, target_len, encoder_attention_mask):
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
@ -266,16 +262,24 @@ class InvokeAIDiffuserComponent:
return cond, encoder_attention_mask
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
encoder_attention_mask = None
if unconditioning.shape[1] != conditioning.shape[1]:
max_len = max(unconditioning.shape[1], conditioning.shape[1])
unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
both_conditionings = torch.cat([unconditioning, conditioning])
return torch.cat([unconditioning, conditioning]), encoder_attention_mask
# methods below are called from do_diffusion_step and should be considered private to this class.
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
# fast batched path
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
unconditioning, conditioning
)
both_results = self.model_forward_callback(
x_twice, sigma_twice, both_conditionings,
encoder_attention_mask=encoder_attention_mask,
@ -293,8 +297,32 @@ class InvokeAIDiffuserComponent:
**kwargs,
):
# low-memory sequential path
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
uncond_down_block, cond_down_block = None, None
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
if down_block_additional_residuals is not None:
uncond_down_block, cond_down_block = [], []
for down_block in down_block_additional_residuals:
_uncond_down, _cond_down = down_block.chunk(2)
uncond_down_block.append(_uncond_down)
cond_down_block.append(_cond_down)
uncond_mid_block, cond_mid_block = None, None
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
unconditioned_next_x = self.model_forward_callback(
x, sigma, unconditioning,
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
**kwargs,
)
conditioned_next_x = self.model_forward_callback(
x, sigma, conditioning,
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
**kwargs,
)
return unconditioned_next_x, conditioned_next_x
# TODO: looks unused
@ -328,6 +356,20 @@ class InvokeAIDiffuserComponent:
):
context: Context = self.cross_attention_control_context
uncond_down_block, cond_down_block = None, None
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
if down_block_additional_residuals is not None:
uncond_down_block, cond_down_block = [], []
for down_block in down_block_additional_residuals:
_uncond_down, _cond_down = down_block.chunk(2)
uncond_down_block.append(_uncond_down)
cond_down_block.append(_cond_down)
uncond_mid_block, cond_mid_block = None, None
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
cross_attn_processor_context = SwapCrossAttnContext(
modified_text_embeddings=context.arguments.edited_conditioning,
index_map=context.cross_attention_index_map,
@ -340,6 +382,8 @@ class InvokeAIDiffuserComponent:
sigma,
unconditioning,
{"swap_cross_attn_context": cross_attn_processor_context},
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
**kwargs,
)
@ -352,6 +396,8 @@ class InvokeAIDiffuserComponent:
sigma,
conditioning,
{"swap_cross_attn_context": cross_attn_processor_context},
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
**kwargs,
)
return unconditioned_next_x, conditioned_next_x

View File

@ -0,0 +1,634 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import (
CrossAttnDownBlock2D,
DownBlock2D,
UNetMidBlock2DCrossAttn,
get_down_block,
)
from diffusers.models.unet_2d_condition import UNet2DConditionModel
import diffusers
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
# Modified ControlNetModel with encoder_attention_mask argument added
class ControlNetModel(ModelMixin, ConfigMixin):
"""
A ControlNet model.
Args:
in_channels (`int`, defaults to 4):
The number of channels in the input sample.
flip_sin_to_cos (`bool`, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, defaults to 0):
The frequency shift to apply to the time embedding.
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, defaults to 2):
The number of layers per block.
downsample_padding (`int`, defaults to 1):
The padding to use for the downsampling convolution.
mid_block_scale_factor (`float`, defaults to 1):
The scale factor to use for the mid block.
act_fn (`str`, defaults to "silu"):
The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
in post-processing.
norm_eps (`float`, defaults to 1e-5):
The epsilon to use for the normalization.
cross_attention_dim (`int`, defaults to 1280):
The dimension of the cross attention features.
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
The dimension of the attention heads.
use_linear_projection (`bool`, defaults to `False`):
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
num_class_embeds (`int`, *optional*, defaults to 0):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
upcast_attention (`bool`, defaults to `False`):
resnet_time_scale_shift (`str`, defaults to `"default"`):
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
`class_embed_type="projection"`.
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
The tuple of output channel for each block in the `conditioning_embedding` layer.
global_pool_conditions (`bool`, defaults to `False`):
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 4,
conditioning_channels: int = 3,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
attention_head_dim: Union[int, Tuple[int]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
projection_class_embeddings_input_dim: Optional[int] = None,
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
global_pool_conditions: bool = False,
):
super().__init__()
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
# Check inputs
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
# input
conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
)
# time
time_embed_dim = block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
act_fn=act_fn,
)
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
elif class_embed_type == "projection":
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
)
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
# 2. it projects from an arbitrary input dimension.
#
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
else:
self.class_embedding = None
# control net conditioning embedding
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
conditioning_embedding_channels=block_out_channels[0],
block_out_channels=conditioning_embedding_out_channels,
conditioning_channels=conditioning_channels,
)
self.down_blocks = nn.ModuleList([])
self.controlnet_down_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[i],
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
self.down_blocks.append(down_block)
for _ in range(layers_per_block):
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
if not is_final_block:
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
# mid
mid_block_channel = block_out_channels[-1]
controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_mid_block = controlnet_block
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=mid_block_channel,
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
@classmethod
def from_unet(
cls,
unet: UNet2DConditionModel,
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
load_weights_from_unet: bool = True,
):
r"""
Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
Parameters:
unet (`UNet2DConditionModel`):
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
where applicable.
"""
controlnet = cls(
in_channels=unet.config.in_channels,
flip_sin_to_cos=unet.config.flip_sin_to_cos,
freq_shift=unet.config.freq_shift,
down_block_types=unet.config.down_block_types,
only_cross_attention=unet.config.only_cross_attention,
block_out_channels=unet.config.block_out_channels,
layers_per_block=unet.config.layers_per_block,
downsample_padding=unet.config.downsample_padding,
mid_block_scale_factor=unet.config.mid_block_scale_factor,
act_fn=unet.config.act_fn,
norm_num_groups=unet.config.norm_num_groups,
norm_eps=unet.config.norm_eps,
cross_attention_dim=unet.config.cross_attention_dim,
attention_head_dim=unet.config.attention_head_dim,
num_attention_heads=unet.config.num_attention_heads,
use_linear_projection=unet.config.use_linear_projection,
class_embed_type=unet.config.class_embed_type,
num_class_embeds=unet.config.num_class_embeds,
upcast_attention=unet.config.upcast_attention,
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
)
if load_weights_from_unet:
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
if controlnet.class_embedding:
controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
return controlnet
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "set_processor"):
processors[f"{name}.processor"] = module.processor
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
self.set_attn_processor(AttnProcessor())
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
sliceable_head_dims.append(module.sliceable_head_dim)
for child in module.children():
fn_recursive_retrieve_sliceable_dims(child)
# retrieve number of attention layers
for module in self.children():
fn_recursive_retrieve_sliceable_dims(module)
num_sliceable_layers = len(sliceable_head_dims)
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
# make smallest slice possible
slice_size = num_sliceable_layers * [1]
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
)
for i in range(len(slice_size)):
size = slice_size[i]
dim = sliceable_head_dims[i]
if size is not None and size > dim:
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
# Recursively walk through all the children.
# Any children which exposes the set_attention_slice method
# gets the message
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
if hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size.pop())
for child in module.children():
fn_recursive_set_attention_slice(child, slice_size)
reversed_slice_size = list(reversed(slice_size))
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
controlnet_cond: torch.FloatTensor,
conditioning_scale: float = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple]:
"""
The [`ControlNetModel`] forward method.
Args:
sample (`torch.FloatTensor`):
The noisy input tensor.
timestep (`Union[torch.Tensor, float, int]`):
The number of timesteps to denoise an input.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states.
controlnet_cond (`torch.FloatTensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
encoder_attention_mask (`torch.Tensor`):
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
which adds large negative values to the attention scores corresponding to "discard" tokens.
guess_mode (`bool`, defaults to `False`):
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
Returns:
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
returned where the first element is the sample tensor.
"""
# check channel order
channel_order = self.config.controlnet_conditioning_channel_order
if channel_order == "rgb":
# in rgb order by default
...
elif channel_order == "bgr":
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
else:
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None:
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
# 2. pre-process
sample = self.conv_in(sample)
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
sample = sample + controlnet_cond
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
if self.mid_block is not None:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)
# 5. Control net blocks
controlnet_down_block_res_samples = ()
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
down_block_res_sample = controlnet_block(down_block_res_sample)
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
down_block_res_samples = controlnet_down_block_res_samples
mid_block_res_sample = self.controlnet_mid_block(sample)
# 6. scaling
if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
scales = scales * conditioning_scale
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
else:
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * conditioning_scale
if self.config.global_pool_conditions:
down_block_res_samples = [
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
]
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
if not return_dict:
return (down_block_res_samples, mid_block_res_sample)
return ControlNetOutput(
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
)
diffusers.ControlNetModel = ControlNetModel
diffusers.models.controlnet.ControlNetModel = ControlNetModel

View File

@ -1,4 +1,10 @@
# This file predefines a few models that the user may want to install.
sd-1/main/stable-diffusion-xdl-base:
description: Stable Diffusion XL base model - NOT YET RELEASED!! (70 GB)
repo_id: stabilityai/stable-diffusion-xl-base
sd-1/main/stable-diffusion-xdl-refiner:
description: Stable Diffusion XL refiner model - NOT YET RELEASED!! (60 GB)
repo_id: stabilityai/stable-diffusion-xl-refiner
sd-1/main/stable-diffusion-v1-5:
description: Stable Diffusion version 1.5 diffusers model (4.27 GB)
repo_id: runwayml/stable-diffusion-v1-5

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -399,6 +399,8 @@
"deleteModel": "Delete Model",
"deleteConfig": "Delete Config",
"deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
"modelDeleted": "Model Deleted",
"modelDeleteFailed": "Failed to delete model",
"deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
"formMessageDiffusersModelLocation": "Diffusers Model Location",
"formMessageDiffusersModelLocationDesc": "Please enter at least one.",
@ -408,11 +410,13 @@
"convertToDiffusers": "Convert To Diffusers",
"convertToDiffusersHelpText1": "This model will be converted to the 🧨 Diffusers format.",
"convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.",
"convertToDiffusersHelpText3": "Your checkpoint file on the disk will NOT be deleted or modified in anyway. You can add your checkpoint to the Model Manager again if you want to.",
"convertToDiffusersHelpText3": "Your checkpoint file on disk WILL be deleted if it is in InvokeAI root folder. If it is in a custom location, then it WILL NOT be deleted.",
"convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.",
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.",
"convertToDiffusersHelpText6": "Do you wish to convert this model?",
"convertToDiffusersSaveLocation": "Save Location",
"noCustomLocationProvided": "No Custom Location Provided",
"convertingModelBegin": "Converting Model. Please wait.",
"v1": "v1",
"v2_base": "v2 (512px)",
"v2_768": "v2 (768px)",
@ -450,7 +454,8 @@
"none": "none",
"addDifference": "Add Difference",
"pickModelType": "Pick Model Type",
"selectModel": "Select Model"
"selectModel": "Select Model",
"importModels": "Import Models"
},
"parameters": {
"general": "General",
@ -572,6 +577,7 @@
"uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image",
"downloadImageStarted": "Image Download Started",
"imageCopied": "Image Copied",
"problemCopyingImage": "Unable to Copy Image",
"imageLinkCopied": "Image Link Copied",
"problemCopyingImageLink": "Unable to Copy Image Link",
"imageNotLoaded": "No Image Loaded",
@ -688,6 +694,15 @@
"reloadSchema": "Reload Schema",
"saveNodes": "Save Nodes",
"loadNodes": "Load Nodes",
"clearNodes": "Clear Nodes"
"clearNodes": "Clear Nodes",
"zoomInNodes": "Zoom In",
"zoomOutNodes": "Zoom Out",
"fitViewportNodes": "Fit View",
"hideGraphNodes": "Hide Graph Overlay",
"showGraphNodes": "Show Graph Overlay",
"hideLegendNodes": "Hide Field Type Legend",
"showLegendNodes": "Show Field Type Legend",
"hideMinimapnodes": "Hide MiniMap",
"showMinimapnodes": "Show MiniMap"
}
}

View File

@ -88,6 +88,8 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addModelLoadStartedEventListener } from './listeners/socketio/socketModelLoadStarted';
import { addModelLoadCompletedEventListener } from './listeners/socketio/socketModelLoadCompleted';
export const listenerMiddleware = createListenerMiddleware();
@ -177,6 +179,8 @@ addSocketConnectedListener();
addSocketDisconnectedListener();
addSocketSubscribedListener();
addSocketUnsubscribedListener();
addModelLoadStartedEventListener();
addModelLoadCompletedEventListener();
// Session Created
addSessionCreatedPendingListener();

View File

@ -0,0 +1,28 @@
import { log } from 'app/logging/useLogger';
import {
appSocketModelLoadCompleted,
socketModelLoadCompleted,
} from 'services/events/actions';
import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' });
export const addModelLoadCompletedEventListener = () => {
startAppListening({
actionCreator: socketModelLoadCompleted,
effect: (action, { dispatch, getState }) => {
const { model_name, model_type, submodel } = action.payload.data;
let modelString = `${model_type} model: ${model_name}`;
if (submodel) {
modelString = modelString.concat(`, submodel: ${submodel}`);
}
moduleLog.debug(action.payload, `Model load completed (${modelString})`);
// pass along the socket event as an application action
dispatch(appSocketModelLoadCompleted(action.payload));
},
});
};

View File

@ -0,0 +1,28 @@
import { log } from 'app/logging/useLogger';
import {
appSocketModelLoadStarted,
socketModelLoadStarted,
} from 'services/events/actions';
import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' });
export const addModelLoadStartedEventListener = () => {
startAppListening({
actionCreator: socketModelLoadStarted,
effect: (action, { dispatch, getState }) => {
const { model_name, model_type, submodel } = action.payload.data;
let modelString = `${model_type} model: ${model_name}`;
if (submodel) {
modelString = modelString.concat(`, submodel: ${submodel}`);
}
moduleLog.debug(action.payload, `Model load started (${modelString})`);
// pass along the socket event as an application action
dispatch(appSocketModelLoadStarted(action.payload));
},
});
};

View File

@ -21,6 +21,7 @@ import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import configReducer from 'features/system/store/configSlice';
import systemReducer from 'features/system/store/systemSlice';
import modelmanagerReducer from 'features/ui/components/tabs/ModelManager/store/modelManagerSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import uiReducer from 'features/ui/store/uiSlice';
@ -49,6 +50,7 @@ const allReducers = {
dynamicPrompts: dynamicPromptsReducer,
imageDeletion: imageDeletionReducer,
lora: loraReducer,
modelmanager: modelmanagerReducer,
[api.reducerPath]: api.reducer,
};
@ -67,6 +69,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'controlNet',
'dynamicPrompts',
'lora',
'modelmanager',
];
export const store = configureStore({

View File

@ -21,6 +21,7 @@ import { ImageDTO } from 'services/api/types';
import { mode } from 'theme/util/mode';
import IAIDraggable from './IAIDraggable';
import IAIDroppable from './IAIDroppable';
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
type IAIDndImageProps = {
imageDTO: ImageDTO | undefined;
@ -96,119 +97,124 @@ const IAIDndImage = (props: IAIDndImageProps) => {
};
return (
<Flex
sx={{
width: 'full',
height: 'full',
alignItems: 'center',
justifyContent: 'center',
position: 'relative',
minW: minSize ? minSize : undefined,
minH: minSize ? minSize : undefined,
userSelect: 'none',
cursor: isDragDisabled || !imageDTO ? 'default' : 'pointer',
}}
>
{imageDTO && (
<ImageContextMenu imageDTO={imageDTO}>
{(ref) => (
<Flex
ref={ref}
sx={{
w: 'full',
h: 'full',
position: fitContainer ? 'absolute' : 'relative',
width: 'full',
height: 'full',
alignItems: 'center',
justifyContent: 'center',
position: 'relative',
minW: minSize ? minSize : undefined,
minH: minSize ? minSize : undefined,
userSelect: 'none',
cursor: isDragDisabled || !imageDTO ? 'default' : 'pointer',
}}
>
<Image
src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url}
fallbackStrategy="beforeLoadOrError"
// If we fall back to thumbnail, it feels much snappier than the skeleton...
fallbackSrc={imageDTO.thumbnail_url}
// fallback={<IAILoadingImageFallback image={imageDTO} />}
width={imageDTO.width}
height={imageDTO.height}
onError={onError}
draggable={false}
sx={{
objectFit: 'contain',
maxW: 'full',
maxH: 'full',
borderRadius: 'base',
shadow: isSelected ? 'selected.light' : undefined,
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
...imageSx,
}}
/>
{withMetadataOverlay && <ImageMetadataOverlay image={imageDTO} />}
</Flex>
)}
{!imageDTO && !isUploadDisabled && (
<>
<Flex
sx={{
minH: minSize,
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
transitionProperty: 'common',
transitionDuration: '0.1s',
color: mode('base.500', 'base.500')(colorMode),
...uploadButtonStyles,
}}
{...getUploadButtonProps()}
>
<input {...getUploadInputProps()} />
<Icon
as={FaUpload}
{imageDTO && (
<Flex
sx={{
boxSize: 16,
w: 'full',
h: 'full',
position: fitContainer ? 'absolute' : 'relative',
alignItems: 'center',
justifyContent: 'center',
}}
>
<Image
src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url}
fallbackStrategy="beforeLoadOrError"
// If we fall back to thumbnail, it feels much snappier than the skeleton...
fallbackSrc={imageDTO.thumbnail_url}
// fallback={<IAILoadingImageFallback image={imageDTO} />}
width={imageDTO.width}
height={imageDTO.height}
onError={onError}
draggable={false}
sx={{
objectFit: 'contain',
maxW: 'full',
maxH: 'full',
borderRadius: 'base',
shadow: isSelected ? 'selected.light' : undefined,
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
...imageSx,
}}
/>
{withMetadataOverlay && <ImageMetadataOverlay image={imageDTO} />}
</Flex>
)}
{!imageDTO && !isUploadDisabled && (
<>
<Flex
sx={{
minH: minSize,
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
transitionProperty: 'common',
transitionDuration: '0.1s',
color: mode('base.500', 'base.500')(colorMode),
...uploadButtonStyles,
}}
{...getUploadButtonProps()}
>
<input {...getUploadInputProps()} />
<Icon
as={FaUpload}
sx={{
boxSize: 16,
}}
/>
</Flex>
</>
)}
{!imageDTO && isUploadDisabled && noContentFallback}
{!isDropDisabled && (
<IAIDroppable
data={droppableData}
disabled={isDropDisabled}
dropLabel={dropLabel}
/>
)}
{imageDTO && !isDragDisabled && (
<IAIDraggable
data={draggableData}
disabled={isDragDisabled || !imageDTO}
onClick={onClick}
/>
)}
{onClickReset && withResetIcon && imageDTO && (
<IAIIconButton
onClick={onClickReset}
aria-label={resetTooltip}
tooltip={resetTooltip}
icon={resetIcon}
size="sm"
variant="link"
sx={{
position: 'absolute',
top: 1,
insetInlineEnd: 1,
p: 0,
minW: 0,
svg: {
transitionProperty: 'common',
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: resetIconShadow,
},
}}
/>
</Flex>
</>
)}
</Flex>
)}
{!imageDTO && isUploadDisabled && noContentFallback}
{!isDropDisabled && (
<IAIDroppable
data={droppableData}
disabled={isDropDisabled}
dropLabel={dropLabel}
/>
)}
{imageDTO && !isDragDisabled && (
<IAIDraggable
data={draggableData}
disabled={isDragDisabled || !imageDTO}
onClick={onClick}
/>
)}
{onClickReset && withResetIcon && imageDTO && (
<IAIIconButton
onClick={onClickReset}
aria-label={resetTooltip}
tooltip={resetTooltip}
icon={resetIcon}
size="sm"
variant="link"
sx={{
position: 'absolute',
top: 1,
insetInlineEnd: 1,
p: 0,
minW: 0,
svg: {
transitionProperty: 'common',
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: resetIconShadow,
},
}}
/>
)}
</Flex>
</ImageContextMenu>
);
};

View File

@ -8,19 +8,34 @@ import {
import { useAppDispatch } from 'app/store/storeHooks';
import { stopPastePropagation } from 'common/util/stopPastePropagation';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { ChangeEvent, KeyboardEvent, memo, useCallback } from 'react';
import {
CSSProperties,
ChangeEvent,
KeyboardEvent,
memo,
useCallback,
} from 'react';
interface IAIInputProps extends InputProps {
label?: string;
labelPos?: 'top' | 'side';
value?: string;
size?: string;
onChange?: (e: ChangeEvent<HTMLInputElement>) => void;
formControlProps?: Omit<FormControlProps, 'isInvalid' | 'isDisabled'>;
}
const labelPosVerticalStyle: CSSProperties = {
display: 'flex',
flexDirection: 'row',
alignItems: 'center',
gap: 10,
};
const IAIInput = (props: IAIInputProps) => {
const {
label = '',
labelPos = 'top',
isDisabled = false,
isInvalid,
formControlProps,
@ -51,6 +66,7 @@ const IAIInput = (props: IAIInputProps) => {
isInvalid={isInvalid}
isDisabled={isDisabled}
{...formControlProps}
style={labelPos === 'side' ? labelPosVerticalStyle : undefined}
>
{label !== '' && <FormLabel>{label}</FormLabel>}
<Input

View File

@ -36,6 +36,7 @@ export default function IAIMantineTextInput(props: IAIMantineTextInputProps) {
label: {
color: mode(base700, base300)(colorMode),
fontWeight: 'normal',
marginBottom: 4,
},
})}
{...rest}

View File

@ -9,14 +9,14 @@ export type IAISelectDataType = {
tooltip?: string;
};
type IAISelectProps = Omit<SelectProps, 'label'> & {
export type IAISelectProps = Omit<SelectProps, 'label'> & {
tooltip?: string;
inputRef?: RefObject<HTMLInputElement>;
label?: string;
};
const IAIMantineSelect = (props: IAISelectProps) => {
const { tooltip, inputRef, label, disabled, ...rest } = props;
const { tooltip, inputRef, label, disabled, required, ...rest } = props;
const styles = useMantineSelectStyles();
@ -25,7 +25,7 @@ const IAIMantineSelect = (props: IAISelectProps) => {
<Select
label={
label ? (
<FormControl isDisabled={disabled}>
<FormControl isRequired={required} isDisabled={disabled}>
<FormLabel>{label}</FormLabel>
</FormControl>
) : undefined

View File

@ -3,4 +3,5 @@ import dateFormat from 'dateformat';
/**
* Get a `now` timestamp with 1s precision, formatted as ISO datetime.
*/
export const getTimestamp = () => dateFormat(new Date(), 'isoDateTime');
export const getTimestamp = () =>
dateFormat(new Date(), `yyyy-mm-dd'T'HH:MM:ss:lo`);

View File

@ -48,6 +48,7 @@ import IAICanvasRedoButton from './IAICanvasRedoButton';
import IAICanvasSettingsButtonPopover from './IAICanvasSettingsButtonPopover';
import IAICanvasToolChooserOptions from './IAICanvasToolChooserOptions';
import IAICanvasUndoButton from './IAICanvasUndoButton';
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
export const selector = createSelector(
[systemSelector, canvasSelector, isStagingSelector],
@ -79,6 +80,7 @@ const IAICanvasToolbar = () => {
const canvasBaseLayer = getCanvasBaseLayer();
const { t } = useTranslation();
const { isClipboardAPIAvailable } = useCopyImageToClipboard();
const { openUploader } = useImageUploader();
@ -136,10 +138,10 @@ const IAICanvasToolbar = () => {
handleCopyImageToClipboard();
},
{
enabled: () => !isStaging,
enabled: () => !isStaging && isClipboardAPIAvailable,
preventDefault: true,
},
[canvasBaseLayer, isProcessing]
[canvasBaseLayer, isProcessing, isClipboardAPIAvailable]
);
useHotkeys(
@ -189,6 +191,9 @@ const IAICanvasToolbar = () => {
};
const handleCopyImageToClipboard = () => {
if (!isClipboardAPIAvailable) {
return;
}
dispatch(canvasCopiedToClipboard());
};
@ -256,13 +261,15 @@ const IAICanvasToolbar = () => {
onClick={handleSaveToGallery}
isDisabled={isStaging}
/>
<IAIIconButton
aria-label={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`}
tooltip={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`}
icon={<FaCopy />}
onClick={handleCopyImageToClipboard}
isDisabled={isStaging}
/>
{isClipboardAPIAvailable && (
<IAIIconButton
aria-label={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`}
tooltip={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`}
icon={<FaCopy />}
onClick={handleCopyImageToClipboard}
isDisabled={isStaging}
/>
)}
<IAIIconButton
aria-label={`${t('unifiedCanvas.downloadAsImage')} (Shift+D)`}
tooltip={`${t('unifiedCanvas.downloadAsImage')} (Shift+D)`}

View File

@ -1,7 +1,16 @@
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash-es';
import { ButtonGroup, Flex, FlexProps, Link } from '@chakra-ui/react';
import {
ButtonGroup,
Flex,
FlexProps,
Link,
Menu,
MenuButton,
MenuItem,
MenuList,
} from '@chakra-ui/react';
// import { runESRGAN, runFacetool } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
@ -20,6 +29,7 @@ import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/U
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import {
setActiveTab,
@ -48,6 +58,8 @@ import {
} from 'services/api/endpoints/images';
import { useDebounce } from 'use-debounce';
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
import { menuListMotionProps } from 'theme/components/menu';
import SingleSelectionMenuItems from '../ImageContextMenu/SingleSelectionMenuItems';
const currentImageButtonsSelector = createSelector(
[stateSelector, activeTabNameSelector],
@ -120,6 +132,9 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const toaster = useAppToaster();
const { t } = useTranslation();
const { isClipboardAPIAvailable, copyImageToClipboard } =
useCopyImageToClipboard();
const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters();
@ -128,7 +143,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
500
);
const { currentData: image, isFetching } = useGetImageDTOQuery(
const { currentData: imageDTO, isFetching } = useGetImageDTOQuery(
lastSelectedImage ?? skipToken
);
@ -142,15 +157,15 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const handleCopyImageLink = useCallback(() => {
const getImageUrl = () => {
if (!image) {
if (!imageDTO) {
return;
}
if (image.image_url.startsWith('http')) {
return image.image_url;
if (imageDTO.image_url.startsWith('http')) {
return imageDTO.image_url;
}
return window.location.toString() + image.image_url;
return window.location.toString() + imageDTO.image_url;
};
const url = getImageUrl();
@ -174,7 +189,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
isClosable: true,
});
});
}, [toaster, t, image]);
}, [toaster, t, imageDTO]);
const handleClickUseAllParameters = useCallback(() => {
recallAllParameters(metadata);
@ -192,31 +207,31 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
recallSeed(metadata?.seed);
}, [metadata?.seed, recallSeed]);
useHotkeys('s', handleUseSeed, [image]);
useHotkeys('s', handleUseSeed, [imageDTO]);
const handleUsePrompt = useCallback(() => {
recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt);
}, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]);
useHotkeys('p', handleUsePrompt, [image]);
useHotkeys('p', handleUsePrompt, [imageDTO]);
const handleSendToImageToImage = useCallback(() => {
dispatch(sentImageToImg2Img());
dispatch(initialImageSelected(image));
}, [dispatch, image]);
dispatch(initialImageSelected(imageDTO));
}, [dispatch, imageDTO]);
useHotkeys('shift+i', handleSendToImageToImage, [image]);
useHotkeys('shift+i', handleSendToImageToImage, [imageDTO]);
const handleClickUpscale = useCallback(() => {
// selectedImage && dispatch(runESRGAN(selectedImage));
}, []);
const handleDelete = useCallback(() => {
if (!image) {
if (!imageDTO) {
return;
}
dispatch(imageToDeleteSelected(image));
}, [dispatch, image]);
dispatch(imageToDeleteSelected(imageDTO));
}, [dispatch, imageDTO]);
useHotkeys(
'Shift+U',
@ -236,7 +251,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
},
[
isUpscalingEnabled,
image,
imageDTO,
isESRGANAvailable,
shouldDisableToolbarButtons,
isConnected,
@ -268,7 +283,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
[
isFaceRestoreEnabled,
image,
imageDTO,
isGFPGANAvailable,
shouldDisableToolbarButtons,
isConnected,
@ -283,10 +298,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
);
const handleSendToCanvas = useCallback(() => {
if (!image) return;
if (!imageDTO) return;
dispatch(sentImageToCanvas());
dispatch(setInitialCanvasImage(image));
dispatch(setInitialCanvasImage(imageDTO));
dispatch(requestCanvasRescale());
if (activeTabName !== 'unifiedCanvas') {
@ -299,12 +314,12 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
duration: 2500,
isClosable: true,
});
}, [image, dispatch, activeTabName, toaster, t]);
}, [imageDTO, dispatch, activeTabName, toaster, t]);
useHotkeys(
'i',
() => {
if (image) {
if (imageDTO) {
handleClickShowImageDetails();
} else {
toaster({
@ -315,13 +330,20 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[image, shouldShowImageDetails, toaster]
[imageDTO, shouldShowImageDetails, toaster]
);
const handleClickProgressImagesToggle = useCallback(() => {
dispatch(setShouldShowProgressInViewer(!shouldShowProgressInViewer));
}, [dispatch, shouldShowProgressInViewer]);
const handleCopyImage = useCallback(() => {
if (!imageDTO) {
return;
}
copyImageToClipboard(imageDTO.image_url);
}, [copyImageToClipboard, imageDTO]);
return (
<>
<Flex
@ -334,63 +356,18 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
{...props}
>
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
<IAIPopover
triggerComponent={
<IAIIconButton
aria-label={`${t('parameters.sendTo')}...`}
tooltip={`${t('parameters.sendTo')}...`}
isDisabled={!image}
icon={<FaShareAlt />}
/>
}
>
<Flex
sx={{
flexDirection: 'column',
rowGap: 2,
}}
>
<IAIButton
size="sm"
onClick={handleSendToImageToImage}
leftIcon={<FaShare />}
id="send-to-img2img"
>
{t('parameters.sendToImg2Img')}
</IAIButton>
{isCanvasEnabled && (
<IAIButton
size="sm"
onClick={handleSendToCanvas}
leftIcon={<FaShare />}
id="send-to-canvas"
>
{t('parameters.sendToUnifiedCanvas')}
</IAIButton>
)}
{/* <IAIButton
size="sm"
onClick={handleCopyImage}
leftIcon={<FaCopy />}
>
{t('parameters.copyImage')}
</IAIButton> */}
<IAIButton
size="sm"
onClick={handleCopyImageLink}
leftIcon={<FaCopy />}
>
{t('parameters.copyImageToLink')}
</IAIButton>
<Link download={true} href={image?.image_url} target="_blank">
<IAIButton leftIcon={<FaDownload />} size="sm" w="100%">
{t('parameters.downloadImage')}
</IAIButton>
</Link>
</Flex>
</IAIPopover>
<Menu>
<MenuButton
as={IAIIconButton}
aria-label={`${t('parameters.sendTo')}...`}
tooltip={`${t('parameters.sendTo')}...`}
isDisabled={!imageDTO}
icon={<FaShareAlt />}
/>
<MenuList motionProps={menuListMotionProps}>
{imageDTO && <SingleSelectionMenuItems imageDTO={imageDTO} />}
</MenuList>
</Menu>
</ButtonGroup>
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
@ -443,7 +420,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<IAIButton
isDisabled={
!isGFPGANAvailable ||
!image ||
!imageDTO ||
!(isConnected && !isProcessing) ||
!facetoolStrength
}
@ -474,7 +451,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<IAIButton
isDisabled={
!isESRGANAvailable ||
!image ||
!imageDTO ||
!(isConnected && !isProcessing) ||
!upscalingLevel
}

View File

@ -4,13 +4,14 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu';
import { memo, useMemo } from 'react';
import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { ImageDTO } from 'services/api/types';
import { menuListMotionProps } from 'theme/components/menu';
import MultipleSelectionMenuItems from './MultipleSelectionMenuItems';
import SingleSelectionMenuItems from './SingleSelectionMenuItems';
type Props = {
imageDTO: ImageDTO;
imageDTO: ImageDTO | undefined;
children: ContextMenuProps<HTMLDivElement>['children'];
};
@ -31,18 +32,32 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => {
const { selectionCount } = useAppSelector(selector);
const handleContextMenu = useCallback((e: MouseEvent<HTMLDivElement>) => {
e.preventDefault();
}, []);
return (
<ContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }}
renderMenu={() => (
<MenuList sx={{ visibility: 'visible !important' }}>
{selectionCount === 1 ? (
<SingleSelectionMenuItems imageDTO={imageDTO} />
) : (
<MultipleSelectionMenuItems />
)}
</MenuList>
)}
menuButtonProps={{
bg: 'transparent',
_hover: { bg: 'transparent' },
}}
renderMenu={() =>
imageDTO ? (
<MenuList
sx={{ visibility: 'visible !important' }}
motionProps={menuListMotionProps}
onContextMenu={handleContextMenu}
>
{selectionCount === 1 ? (
<SingleSelectionMenuItems imageDTO={imageDTO} />
) : (
<MultipleSelectionMenuItems />
)}
</MenuList>
) : null
}
>
{children}
</ContextMenu>

View File

@ -1,5 +1,4 @@
import { ExternalLinkIcon } from '@chakra-ui/icons';
import { MenuItem } from '@chakra-ui/react';
import { Link, MenuItem } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppToaster } from 'app/components/Toaster';
import { stateSelector } from 'app/store/store';
@ -14,11 +13,21 @@ import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletio
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { memo, useCallback, useContext, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { FaFolder, FaShare, FaTrash } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import {
FaAsterisk,
FaCopy,
FaDownload,
FaExternalLinkAlt,
FaFolder,
FaQuoteRight,
FaSeedling,
FaShare,
FaTrash,
} from 'react-icons/fa';
import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages';
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types';
@ -61,6 +70,9 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const { currentData } = useGetImageMetadataQuery(imageDTO.image_name);
const { isClipboardAPIAvailable, copyImageToClipboard } =
useCopyImageToClipboard();
const metadata = currentData?.metadata;
const handleDelete = useCallback(() => {
@ -130,13 +142,27 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
dispatch(imagesAddedToBatch([imageDTO.image_name]));
}, [dispatch, imageDTO.image_name]);
const handleCopyImage = useCallback(() => {
copyImageToClipboard(imageDTO.image_url);
}, [copyImageToClipboard, imageDTO.image_url]);
return (
<>
<MenuItem icon={<ExternalLinkIcon />} onClickCapture={handleOpenInNewTab}>
{t('common.openInNewTab')}
</MenuItem>
<Link href={imageDTO.image_url} target="_blank">
<MenuItem
icon={<FaExternalLinkAlt />}
onClickCapture={handleOpenInNewTab}
>
{t('common.openInNewTab')}
</MenuItem>
</Link>
{isClipboardAPIAvailable && (
<MenuItem icon={<FaCopy />} onClickCapture={handleCopyImage}>
{t('parameters.copyImage')}
</MenuItem>
)}
<MenuItem
icon={<IoArrowUndoCircleOutline />}
icon={<FaQuoteRight />}
onClickCapture={handleRecallPrompt}
isDisabled={
metadata?.positive_prompt === undefined &&
@ -147,14 +173,14 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
</MenuItem>
<MenuItem
icon={<IoArrowUndoCircleOutline />}
icon={<FaSeedling />}
onClickCapture={handleRecallSeed}
isDisabled={metadata?.seed === undefined}
>
{t('parameters.useSeed')}
</MenuItem>
<MenuItem
icon={<IoArrowUndoCircleOutline />}
icon={<FaAsterisk />}
onClickCapture={handleUseAllParameters}
isDisabled={!metadata}
>
@ -193,6 +219,11 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
Remove from Board
</MenuItem>
)}
<Link download={true} href={imageDTO.image_url} target="_blank">
<MenuItem icon={<FaDownload />} w="100%">
{t('parameters.downloadImage')}
</MenuItem>
</Link>
<MenuItem
sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
icon={<FaTrash />}

View File

@ -16,14 +16,13 @@ import {
ASSETS_CATEGORIES,
IMAGE_CATEGORIES,
IMAGE_LIMIT,
selectImagesAll,
} from 'features/gallery//store/gallerySlice';
import { selectFilteredImages } from 'features/gallery/store/gallerySelectors';
import { VirtuosoGrid } from 'react-virtuoso';
import { receivedPageOfImages } from 'services/api/thunks/image';
import { useListBoardImagesQuery } from '../../../../services/api/endpoints/boardImages';
import ImageGridItemContainer from './ImageGridItemContainer';
import ImageGridListContainer from './ImageGridListContainer';
import { useListBoardImagesQuery } from '../../../../services/api/endpoints/boardImages';
const selector = createSelector(
[stateSelector, selectFilteredImages],
@ -180,7 +179,6 @@ const GalleryImageGrid = () => {
</Box>
);
}
console.log({ selectedBoardId });
if (status !== 'rejected') {
return (

View File

@ -1,17 +1,36 @@
import { ButtonGroup } from '@chakra-ui/react';
import { ButtonGroup, Tooltip } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { memo, useCallback } from 'react';
import { FaCode, FaExpand, FaMinus, FaPlus } from 'react-icons/fa';
import {
FaCode,
FaExpand,
FaMinus,
FaPlus,
FaInfo,
FaMapMarkerAlt,
} from 'react-icons/fa';
import { useReactFlow } from 'reactflow';
import { shouldShowGraphOverlayChanged } from '../store/nodesSlice';
import { useTranslation } from 'react-i18next';
import {
shouldShowGraphOverlayChanged,
shouldShowFieldTypeLegendChanged,
shouldShowMinimapPanelChanged,
} from '../store/nodesSlice';
const ViewportControls = () => {
const { t } = useTranslation();
const { zoomIn, zoomOut, fitView } = useReactFlow();
const dispatch = useAppDispatch();
const shouldShowGraphOverlay = useAppSelector(
(state) => state.nodes.shouldShowGraphOverlay
);
const shouldShowFieldTypeLegend = useAppSelector(
(state) => state.nodes.shouldShowFieldTypeLegend
);
const shouldShowMinimapPanel = useAppSelector(
(state) => state.nodes.shouldShowMinimapPanel
);
const handleClickedZoomIn = useCallback(() => {
zoomIn();
@ -29,29 +48,64 @@ const ViewportControls = () => {
dispatch(shouldShowGraphOverlayChanged(!shouldShowGraphOverlay));
}, [shouldShowGraphOverlay, dispatch]);
const handleClickedToggleFieldTypeLegend = useCallback(() => {
dispatch(shouldShowFieldTypeLegendChanged(!shouldShowFieldTypeLegend));
}, [shouldShowFieldTypeLegend, dispatch]);
const handleClickedToggleMiniMapPanel = useCallback(() => {
dispatch(shouldShowMinimapPanelChanged(!shouldShowMinimapPanel));
}, [shouldShowMinimapPanel, dispatch]);
return (
<ButtonGroup isAttached orientation="vertical">
<IAIIconButton
onClick={handleClickedZoomIn}
aria-label="Zoom In"
icon={<FaPlus />}
/>
<IAIIconButton
onClick={handleClickedZoomOut}
aria-label="Zoom Out"
icon={<FaMinus />}
/>
<IAIIconButton
onClick={handleClickedFitView}
aria-label="Fit to Viewport"
icon={<FaExpand />}
/>
<IAIIconButton
isChecked={shouldShowGraphOverlay}
onClick={handleClickedToggleGraphOverlay}
aria-label="Show/Hide Graph"
icon={<FaCode />}
/>
<Tooltip label={t('nodes.zoomInNodes')}>
<IAIIconButton onClick={handleClickedZoomIn} icon={<FaPlus />} />
</Tooltip>
<Tooltip label={t('nodes.zoomOutNodes')}>
<IAIIconButton onClick={handleClickedZoomOut} icon={<FaMinus />} />
</Tooltip>
<Tooltip label={t('nodes.fitViewportNodes')}>
<IAIIconButton onClick={handleClickedFitView} icon={<FaExpand />} />
</Tooltip>
<Tooltip
label={
shouldShowGraphOverlay
? t('nodes.hideGraphNodes')
: t('nodes.showGraphNodes')
}
>
<IAIIconButton
isChecked={shouldShowGraphOverlay}
onClick={handleClickedToggleGraphOverlay}
icon={<FaCode />}
/>
</Tooltip>
<Tooltip
label={
shouldShowFieldTypeLegend
? t('nodes.hideLegendNodes')
: t('nodes.showLegendNodes')
}
>
<IAIIconButton
isChecked={shouldShowFieldTypeLegend}
onClick={handleClickedToggleFieldTypeLegend}
icon={<FaInfo />}
/>
</Tooltip>
<Tooltip
label={
shouldShowMinimapPanel
? t('nodes.hideMinimapnodes')
: t('nodes.showMinimapnodes')
}
>
<IAIIconButton
isChecked={shouldShowMinimapPanel}
onClick={handleClickedToggleMiniMapPanel}
icon={<FaMapMarkerAlt />}
/>
</Tooltip>
</ButtonGroup>
);
};

View File

@ -1,3 +1,5 @@
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { useColorModeValue } from '@chakra-ui/react';
import { memo } from 'react';
import { MiniMap } from 'reactflow';
@ -12,6 +14,10 @@ const MinimapPanel = () => {
}
);
const shouldShowMinimapPanel = useAppSelector(
(state: RootState) => state.nodes.shouldShowMinimapPanel
);
const nodeColor = useColorModeValue(
'var(--invokeai-colors-accent-300)',
'var(--invokeai-colors-accent-700)'
@ -23,15 +29,19 @@ const MinimapPanel = () => {
);
return (
<MiniMap
nodeStrokeWidth={3}
pannable
zoomable
nodeBorderRadius={30}
style={miniMapStyle}
nodeColor={nodeColor}
maskColor={maskColor}
/>
<>
{shouldShowMinimapPanel && (
<MiniMap
nodeStrokeWidth={3}
pannable
zoomable
nodeBorderRadius={30}
style={miniMapStyle}
nodeColor={nodeColor}
maskColor={maskColor}
/>
)}
</>
);
};

View File

@ -9,10 +9,13 @@ const TopRightPanel = () => {
const shouldShowGraphOverlay = useAppSelector(
(state: RootState) => state.nodes.shouldShowGraphOverlay
);
const shouldShowFieldTypeLegend = useAppSelector(
(state: RootState) => state.nodes.shouldShowFieldTypeLegend
);
return (
<Panel position="top-right">
<FieldTypeLegend />
{shouldShowFieldTypeLegend && <FieldTypeLegend />}
{shouldShowGraphOverlay && <NodeGraphOverlay />}
</Panel>
);

View File

@ -32,6 +32,8 @@ export type NodesState = {
invocationTemplates: Record<string, InvocationTemplate>;
connectionStartParams: OnConnectStartParams | null;
shouldShowGraphOverlay: boolean;
shouldShowFieldTypeLegend: boolean;
shouldShowMinimapPanel: boolean;
editorInstance: ReactFlowInstance | undefined;
};
@ -42,6 +44,8 @@ export const initialNodesState: NodesState = {
invocationTemplates: {},
connectionStartParams: null,
shouldShowGraphOverlay: false,
shouldShowFieldTypeLegend: false,
shouldShowMinimapPanel: true,
editorInstance: undefined,
};
@ -125,6 +129,15 @@ const nodesSlice = createSlice({
shouldShowGraphOverlayChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowGraphOverlay = action.payload;
},
shouldShowFieldTypeLegendChanged: (
state,
action: PayloadAction<boolean>
) => {
state.shouldShowFieldTypeLegend = action.payload;
},
shouldShowMinimapPanelChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowMinimapPanel = action.payload;
},
nodeTemplatesBuilt: (
state,
action: PayloadAction<Record<string, InvocationTemplate>>
@ -161,6 +174,8 @@ export const {
connectionStarted,
connectionEnded,
shouldShowGraphOverlayChanged,
shouldShowFieldTypeLegendChanged,
shouldShowMinimapPanelChanged,
nodeTemplatesBuilt,
nodeEditorReset,
imageCollectionFieldValueChanged,

View File

@ -14,6 +14,14 @@ export const clipSkipMap = {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
},
sdxl: {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
},
'sdxl-refiner': {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
},
};
export default function ParamClipSkip() {

View File

@ -1,4 +1,6 @@
export const MODEL_TYPE_MAP = {
'sd-1': 'Stable Diffusion 1.x',
'sd-2': 'Stable Diffusion 2.x',
sdxl: 'Stable Diffusion XL',
'sdxl-refiner': 'Stable Diffusion XL Refiner',
};

View File

@ -126,7 +126,7 @@ export type HeightParam = z.infer<typeof zHeight>;
export const isValidHeight = (val: unknown): val is HeightParam =>
zHeight.safeParse(val).success;
const zBaseModel = z.enum(['sd-1', 'sd-2']);
const zBaseModel = z.enum(['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
export type BaseModelParam = z.infer<typeof zBaseModel>;

View File

@ -1,11 +1,10 @@
import { UseToastOptions } from '@chakra-ui/react';
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/types/invokeai';
import { InvokeLogLevel } from 'app/logging/useLogger';
import { userInvoked } from 'app/store/actions';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { TFuncKey, t } from 'i18next';
import { t } from 'i18next';
import { LogLevelName } from 'roarr';
import { imageUploaded } from 'services/api/thunks/image';
import {
@ -44,8 +43,6 @@ export interface SystemState {
isCancelable: boolean;
enableImageDebugging: boolean;
toastQueue: UseToastOptions[];
searchFolder: string | null;
foundModels: InvokeAI.FoundModel[] | null;
/**
* The current progress image
*/
@ -79,7 +76,7 @@ export interface SystemState {
*/
consoleLogLevel: InvokeLogLevel;
shouldLogToConsole: boolean;
statusTranslationKey: TFuncKey;
statusTranslationKey: any;
/**
* When a session is canceled, its ID is stored here until a new session is created.
*/
@ -106,8 +103,6 @@ export const initialSystemState: SystemState = {
isCancelable: true,
enableImageDebugging: false,
toastQueue: [],
searchFolder: null,
foundModels: null,
progressImage: null,
shouldAntialiasProgressImage: false,
sessionId: null,
@ -132,7 +127,7 @@ export const systemSlice = createSlice({
setIsProcessing: (state, action: PayloadAction<boolean>) => {
state.isProcessing = action.payload;
},
setCurrentStatus: (state, action: PayloadAction<TFuncKey>) => {
setCurrentStatus: (state, action: any) => {
state.statusTranslationKey = action.payload;
},
setShouldConfirmOnDelete: (state, action: PayloadAction<boolean>) => {
@ -153,15 +148,6 @@ export const systemSlice = createSlice({
clearToastQueue: (state) => {
state.toastQueue = [];
},
setSearchFolder: (state, action: PayloadAction<string | null>) => {
state.searchFolder = action.payload;
},
setFoundModels: (
state,
action: PayloadAction<InvokeAI.FoundModel[] | null>
) => {
state.foundModels = action.payload;
},
/**
* A cancel was scheduled
*/
@ -426,8 +412,6 @@ export const {
setEnableImageDebugging,
addToast,
clearToastQueue,
setSearchFolder,
setFoundModels,
cancelScheduled,
scheduledCancelAborted,
cancelTypeChanged,

View File

@ -1,11 +1,11 @@
import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@chakra-ui/react';
import i18n from 'i18n';
import { ReactNode, memo } from 'react';
import AddModelsPanel from './subpanels/AddModelsPanel';
import ImportModelsPanel from './subpanels/ImportModelsPanel';
import MergeModelsPanel from './subpanels/MergeModelsPanel';
import ModelManagerPanel from './subpanels/ModelManagerPanel';
type ModelManagerTabName = 'modelManager' | 'addModels' | 'mergeModels';
type ModelManagerTabName = 'modelManager' | 'importModels' | 'mergeModels';
type ModelManagerTabInfo = {
id: ModelManagerTabName;
@ -20,9 +20,9 @@ const tabs: ModelManagerTabInfo[] = [
content: <ModelManagerPanel />,
},
{
id: 'addModels',
label: i18n.t('modelManager.addModel'),
content: <AddModelsPanel />,
id: 'importModels',
label: i18n.t('modelManager.importModels'),
content: <ImportModelsPanel />,
},
{
id: 'mergeModels',
@ -46,7 +46,7 @@ const ModelManagerTab = () => {
</Tab>
))}
</TabList>
<TabPanels sx={{ w: 'full', h: 'full', p: 4 }}>
<TabPanels sx={{ w: 'full', h: 'full' }}>
{tabs.map((tab) => (
<TabPanel sx={{ w: 'full', h: 'full' }} key={tab.id}>
{tab.content}

View File

@ -0,0 +1,29 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
type ModelManagerState = {
searchFolder: string | null;
advancedAddScanModel: string | null;
};
const initialModelManagerState: ModelManagerState = {
searchFolder: null,
advancedAddScanModel: null,
};
export const modelManagerSlice = createSlice({
name: 'modelmanager',
initialState: initialModelManagerState,
reducers: {
setSearchFolder: (state, action: PayloadAction<string | null>) => {
state.searchFolder = action.payload;
},
setAdvancedAddScanModel: (state, action: PayloadAction<string | null>) => {
state.advancedAddScanModel = action.payload;
},
},
});
export const { setSearchFolder, setAdvancedAddScanModel } =
modelManagerSlice.actions;
export default modelManagerSlice.reducer;

View File

@ -0,0 +1,3 @@
import { RootState } from 'app/store/store';
export const modelmanagerSelector = (state: RootState) => state.modelmanager;

View File

@ -1,43 +0,0 @@
import { Divider, Flex, useColorMode } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import { useTranslation } from 'react-i18next';
import AddCheckpointModel from './AddModelsPanel/AddCheckpointModel';
import AddDiffusersModel from './AddModelsPanel/AddDiffusersModel';
export default function AddModelsPanel() {
const addNewModelUIOption = useAppSelector(
(state: RootState) => state.ui.addNewModelUIOption
);
const { colorMode } = useColorMode();
const dispatch = useAppDispatch();
const { t } = useTranslation();
return (
<Flex flexDirection="column" gap={4}>
<Flex columnGap={4}>
<IAIButton
onClick={() => dispatch(setAddNewModelUIOption('ckpt'))}
isChecked={addNewModelUIOption == 'ckpt'}
>
{t('modelManager.addCheckpointModel')}
</IAIButton>
<IAIButton
onClick={() => dispatch(setAddNewModelUIOption('diffusers'))}
isChecked={addNewModelUIOption == 'diffusers'}
>
{t('modelManager.addDiffuserModel')}
</IAIButton>
</Flex>
<Divider />
{addNewModelUIOption == 'ckpt' && <AddCheckpointModel />}
{addNewModelUIOption == 'diffusers' && <AddDiffusersModel />}
</Flex>
);
}

View File

@ -1,337 +0,0 @@
import {
Flex,
FormControl,
FormErrorMessage,
FormHelperText,
FormLabel,
HStack,
Text,
VStack,
} from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import React from 'react';
// import { addNewModel } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next';
import type { RootState } from 'app/store/store';
import type { InvokeModelConfigProps } from 'app/types/invokeai';
import IAIForm from 'common/components/IAIForm';
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import type { FieldInputProps, FormikProps } from 'formik';
import SearchModels from './SearchModels';
const MIN_MODEL_SIZE = 64;
const MAX_MODEL_SIZE = 2048;
export default function AddCheckpointModel() {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
function hasWhiteSpace(s: string) {
return /\s/.test(s);
}
function baseValidation(value: string) {
let error;
if (hasWhiteSpace(value)) error = t('modelManager.cannotUseSpaces');
return error;
}
const addModelFormValues: InvokeModelConfigProps = {
name: '',
description: '',
config: 'configs/stable-diffusion/v1-inference.yaml',
weights: '',
vae: '',
width: 512,
height: 512,
format: 'ckpt',
default: false,
};
const addModelFormSubmitHandler = (values: InvokeModelConfigProps) => {
dispatch(addNewModel(values));
dispatch(setAddNewModelUIOption(null));
};
const [addManually, setAddmanually] = React.useState<boolean>(false);
return (
<VStack gap={2} alignItems="flex-start">
<Flex columnGap={4}>
<IAISimpleCheckbox
isChecked={!addManually}
label={t('modelManager.scanForModels')}
onChange={() => setAddmanually(!addManually)}
/>
<IAISimpleCheckbox
label={t('modelManager.addManually')}
isChecked={addManually}
onChange={() => setAddmanually(!addManually)}
/>
</Flex>
{addManually ? (
<Formik
initialValues={addModelFormValues}
onSubmit={addModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<IAIForm onSubmit={handleSubmit} sx={{ w: 'full' }}>
<VStack rowGap={2}>
<Text fontSize={20} fontWeight="bold" alignSelf="start">
{t('modelManager.manual')}
</Text>
{/* Name */}
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.name && touched.name}
isRequired
>
<FormLabel htmlFor="name" fontSize="sm">
{t('modelManager.name')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="name"
name="name"
type="text"
validate={baseValidation}
width="full"
/>
{!!errors.name && touched.name ? (
<FormErrorMessage>{errors.name}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.nameValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Description */}
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.description && touched.description}
isRequired
>
<FormLabel htmlFor="description" fontSize="sm">
{t('modelManager.description')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="description"
name="description"
type="text"
width="full"
/>
{!!errors.description && touched.description ? (
<FormErrorMessage>
{errors.description}
</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.descriptionValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Config */}
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.config && touched.config}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.config')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="config"
name="config"
type="text"
width="full"
/>
{!!errors.config && touched.config ? (
<FormErrorMessage>{errors.config}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.configValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Weights */}
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.weights && touched.weights}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="weights"
name="weights"
type="text"
width="full"
/>
{!!errors.weights && touched.weights ? (
<FormErrorMessage>{errors.weights}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.modelLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* VAE */}
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.vae && touched.vae}>
<FormLabel htmlFor="vae" fontSize="sm">
{t('modelManager.vaeLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae"
name="vae"
type="text"
width="full"
/>
{!!errors.vae && touched.vae ? (
<FormErrorMessage>{errors.vae}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.vaeLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<HStack width="100%">
{/* Width */}
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.width && touched.width}>
<FormLabel htmlFor="width" fontSize="sm">
{t('modelManager.width')}
</FormLabel>
<VStack alignItems="start">
<Field id="width" name="width">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="width"
name="width"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
value={form.values.width}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
{!!errors.width && touched.width ? (
<FormErrorMessage>{errors.width}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.widthValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Height */}
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.height && touched.height}>
<FormLabel htmlFor="height" fontSize="sm">
{t('modelManager.height')}
</FormLabel>
<VStack alignItems="start">
<Field id="height" name="height">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="height"
name="height"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
value={form.values.height}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
{!!errors.height && touched.height ? (
<FormErrorMessage>{errors.height}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.heightValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
</HStack>
<IAIButton
type="submit"
className="modal-close-btn"
isLoading={isProcessing}
>
{t('modelManager.addModel')}
</IAIButton>
</VStack>
</IAIForm>
)}
</Formik>
) : (
<SearchModels />
)}
</VStack>
);
}

View File

@ -1,259 +0,0 @@
import {
Flex,
FormControl,
FormErrorMessage,
FormHelperText,
FormLabel,
Text,
VStack,
} from '@chakra-ui/react';
import { InvokeDiffusersModelConfigProps } from 'app/types/invokeai';
// import { addNewModel } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next';
import type { RootState } from 'app/store/store';
import IAIForm from 'common/components/IAIForm';
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
export default function AddDiffusersModel() {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
function hasWhiteSpace(s: string) {
return /\s/.test(s);
}
function baseValidation(value: string) {
let error;
if (hasWhiteSpace(value)) error = t('modelManager.cannotUseSpaces');
return error;
}
const addModelFormValues: InvokeDiffusersModelConfigProps = {
name: '',
description: '',
repo_id: '',
path: '',
format: 'diffusers',
default: false,
vae: {
repo_id: '',
path: '',
},
};
const addModelFormSubmitHandler = (
values: InvokeDiffusersModelConfigProps
) => {
const diffusersModelToAdd = values;
if (values.path === '') delete diffusersModelToAdd.path;
if (values.repo_id === '') delete diffusersModelToAdd.repo_id;
if (values.vae.path === '') delete diffusersModelToAdd.vae.path;
if (values.vae.repo_id === '') delete diffusersModelToAdd.vae.repo_id;
dispatch(addNewModel(diffusersModelToAdd));
dispatch(setAddNewModelUIOption(null));
};
return (
<Flex overflow="scroll" maxHeight={window.innerHeight - 270} width="100%">
<Formik
initialValues={addModelFormValues}
onSubmit={addModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<IAIForm onSubmit={handleSubmit} w="full">
<VStack rowGap={2}>
<IAIFormItemWrapper>
{/* Name */}
<FormControl
isInvalid={!!errors.name && touched.name}
isRequired
>
<FormLabel htmlFor="name" fontSize="sm">
{t('modelManager.name')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="name"
name="name"
type="text"
validate={baseValidation}
isRequired
/>
{!!errors.name && touched.name ? (
<FormErrorMessage>{errors.name}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.nameValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<IAIFormItemWrapper>
{/* Description */}
<FormControl
isInvalid={!!errors.description && touched.description}
isRequired
>
<FormLabel htmlFor="description" fontSize="sm">
{t('modelManager.description')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="description"
name="description"
type="text"
isRequired
/>
{!!errors.description && touched.description ? (
<FormErrorMessage>{errors.description}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.descriptionValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<IAIFormItemWrapper>
<Text fontWeight="bold" fontSize="sm">
{t('modelManager.formMessageDiffusersModelLocation')}
</Text>
<Text
sx={{
fontSize: 'sm',
fontStyle: 'italic',
}}
variant="subtext"
>
{t('modelManager.formMessageDiffusersModelLocationDesc')}
</Text>
{/* Path */}
<FormControl isInvalid={!!errors.path && touched.path}>
<FormLabel htmlFor="path" fontSize="sm">
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field as={IAIInput} id="path" name="path" type="text" />
{!!errors.path && touched.path ? (
<FormErrorMessage>{errors.path}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.modelLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
{/* Repo ID */}
<FormControl isInvalid={!!errors.repo_id && touched.repo_id}>
<FormLabel htmlFor="repo_id" fontSize="sm">
{t('modelManager.repo_id')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="repo_id"
name="repo_id"
type="text"
/>
{!!errors.repo_id && touched.repo_id ? (
<FormErrorMessage>{errors.repo_id}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.repoIDValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<IAIFormItemWrapper>
{/* VAE Path */}
<Text fontWeight="bold">
{t('modelManager.formMessageDiffusersVAELocation')}
</Text>
<Text
sx={{
fontSize: 'sm',
fontStyle: 'italic',
}}
variant="subtext"
>
{t('modelManager.formMessageDiffusersVAELocationDesc')}
</Text>
<FormControl
isInvalid={!!errors.vae?.path && touched.vae?.path}
>
<FormLabel htmlFor="vae.path" fontSize="sm">
{t('modelManager.vaeLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae.path"
name="vae.path"
type="text"
/>
{!!errors.vae?.path && touched.vae?.path ? (
<FormErrorMessage>{errors.vae?.path}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.vaeLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
{/* VAE Repo ID */}
<FormControl
isInvalid={!!errors.vae?.repo_id && touched.vae?.repo_id}
>
<FormLabel htmlFor="vae.repo_id" fontSize="sm">
{t('modelManager.vaeRepoID')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae.repo_id"
name="vae.repo_id"
type="text"
/>
{!!errors.vae?.repo_id && touched.vae?.repo_id ? (
<FormErrorMessage>{errors.vae?.repo_id}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.vaeRepoIDValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<IAIButton type="submit" isLoading={isProcessing}>
{t('modelManager.addModel')}
</IAIButton>
</VStack>
</IAIForm>
)}
</Formik>
</Flex>
);
}

View File

@ -0,0 +1,48 @@
import { ButtonGroup, Flex } from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import { useState } from 'react';
import AdvancedAddModels from './AdvancedAddModels';
import SimpleAddModels from './SimpleAddModels';
export default function AddModels() {
const [addModelMode, setAddModelMode] = useState<'simple' | 'advanced'>(
'simple'
);
return (
<Flex
flexDirection="column"
width="100%"
overflow="scroll"
maxHeight={window.innerHeight - 250}
gap={4}
>
<ButtonGroup isAttached>
<IAIButton
size="sm"
isChecked={addModelMode == 'simple'}
onClick={() => setAddModelMode('simple')}
>
Simple
</IAIButton>
<IAIButton
size="sm"
isChecked={addModelMode == 'advanced'}
onClick={() => setAddModelMode('advanced')}
>
Advanced
</IAIButton>
</ButtonGroup>
<Flex
sx={{
p: 4,
borderRadius: 4,
background: 'base.200',
_dark: { background: 'base.800' },
}}
>
{addModelMode === 'simple' && <SimpleAddModels />}
{addModelMode === 'advanced' && <AdvancedAddModels />}
</Flex>
</Flex>
);
}

View File

@ -0,0 +1,143 @@
import { Flex } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import { addToast } from 'features/system/store/systemSlice';
import { useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
import { CheckpointModelConfig } from 'services/api/types';
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
import BaseModelSelect from '../shared/BaseModelSelect';
import CheckpointConfigsSelect from '../shared/CheckpointConfigsSelect';
import ModelVariantSelect from '../shared/ModelVariantSelect';
type AdvancedAddCheckpointProps = {
model_path?: string;
};
export default function AdvancedAddCheckpoint(
props: AdvancedAddCheckpointProps
) {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { model_path } = props;
const advancedAddCheckpointForm = useForm<CheckpointModelConfig>({
initialValues: {
model_name: model_path
? model_path.split('\\').splice(-1)[0].split('.')[0]
: '',
base_model: 'sd-1',
model_type: 'main',
path: model_path ? model_path : '',
description: '',
model_format: 'checkpoint',
error: undefined,
vae: '',
variant: 'normal',
config: 'configs\\stable-diffusion\\v1-inference.yaml',
},
});
const [addMainModel] = useAddMainModelsMutation();
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
const advancedAddCheckpointFormHandler = (values: CheckpointModelConfig) => {
addMainModel({
body: values,
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: `Model Added: ${values.model_name}`,
status: 'success',
})
)
);
advancedAddCheckpointForm.reset();
// Close Advanced Panel in Scan Models tab
if (model_path) {
dispatch(setAdvancedAddScanModel(null));
}
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: 'Model Add Failed',
status: 'error',
})
)
);
}
});
};
return (
<form
onSubmit={advancedAddCheckpointForm.onSubmit((v) =>
advancedAddCheckpointFormHandler(v)
)}
style={{ width: '100%' }}
>
<Flex flexDirection="column" gap={2}>
<IAIMantineTextInput
label="Model Name"
required
{...advancedAddCheckpointForm.getInputProps('model_name')}
/>
<BaseModelSelect
{...advancedAddCheckpointForm.getInputProps('base_model')}
/>
<IAIMantineTextInput
label="Model Location"
required
{...advancedAddCheckpointForm.getInputProps('path')}
/>
<IAIMantineTextInput
label="Description"
{...advancedAddCheckpointForm.getInputProps('description')}
/>
<IAIMantineTextInput
label="VAE Location"
{...advancedAddCheckpointForm.getInputProps('vae')}
/>
<ModelVariantSelect
{...advancedAddCheckpointForm.getInputProps('variant')}
/>
<Flex flexDirection="column" width="100%" gap={2}>
{!useCustomConfig ? (
<CheckpointConfigsSelect
required
width="100%"
{...advancedAddCheckpointForm.getInputProps('config')}
/>
) : (
<IAIMantineTextInput
required
label="Custom Config File Location"
{...advancedAddCheckpointForm.getInputProps('config')}
/>
)}
<IAISimpleCheckbox
isChecked={useCustomConfig}
onChange={() => setUseCustomConfig(!useCustomConfig)}
label="Use Custom Config"
/>
<IAIButton mt={2} type="submit">
{t('modelManager.addModel')}
</IAIButton>
</Flex>
</Flex>
</form>
);
}

View File

@ -0,0 +1,113 @@
import { Flex } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import { addToast } from 'features/system/store/systemSlice';
import { useTranslation } from 'react-i18next';
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
import { DiffusersModelConfig } from 'services/api/types';
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
import BaseModelSelect from '../shared/BaseModelSelect';
import ModelVariantSelect from '../shared/ModelVariantSelect';
type AdvancedAddDiffusersProps = {
model_path?: string;
};
export default function AdvancedAddDiffusers(props: AdvancedAddDiffusersProps) {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { model_path } = props;
const [addMainModel] = useAddMainModelsMutation();
const advancedAddDiffusersForm = useForm<DiffusersModelConfig>({
initialValues: {
model_name: model_path ? model_path.split('\\').splice(-1)[0] : '',
base_model: 'sd-1',
model_type: 'main',
path: model_path ? model_path : '',
description: '',
model_format: 'diffusers',
error: undefined,
vae: '',
variant: 'normal',
},
});
const advancedAddDiffusersFormHandler = (values: DiffusersModelConfig) => {
addMainModel({
body: values,
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: `Model Added: ${values.model_name}`,
status: 'success',
})
)
);
advancedAddDiffusersForm.reset();
// Close Advanced Panel in Scan Models tab
if (model_path) {
dispatch(setAdvancedAddScanModel(null));
}
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: 'Model Add Failed',
status: 'error',
})
)
);
}
});
};
return (
<form
onSubmit={advancedAddDiffusersForm.onSubmit((v) =>
advancedAddDiffusersFormHandler(v)
)}
style={{ width: '100%' }}
>
<Flex flexDirection="column" gap={2}>
<IAIMantineTextInput
required
label="Model Name"
{...advancedAddDiffusersForm.getInputProps('model_name')}
/>
<BaseModelSelect
{...advancedAddDiffusersForm.getInputProps('base_model')}
/>
<IAIMantineTextInput
required
label="Model Location"
placeholder="Provide the path to a local folder where your Diffusers Model is stored"
{...advancedAddDiffusersForm.getInputProps('path')}
/>
<IAIMantineTextInput
label="Description"
{...advancedAddDiffusersForm.getInputProps('description')}
/>
<IAIMantineTextInput
label="VAE Location"
{...advancedAddDiffusersForm.getInputProps('vae')}
/>
<ModelVariantSelect
{...advancedAddDiffusersForm.getInputProps('variant')}
/>
<IAIButton mt={2} type="submit">
{t('modelManager.addModel')}
</IAIButton>
</Flex>
</form>
);
}

View File

@ -0,0 +1,46 @@
import { Flex } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { useState } from 'react';
import AdvancedAddCheckpoint from './AdvancedAddCheckpoint';
import AdvancedAddDiffusers from './AdvancedAddDiffusers';
export const advancedAddModeData: SelectItem[] = [
{ label: 'Diffusers', value: 'diffusers' },
{ label: 'Checkpoint / Safetensors', value: 'checkpoint' },
];
export type ManualAddMode = 'diffusers' | 'checkpoint';
export default function AdvancedAddModels() {
const [advancedAddMode, setAdvancedAddMode] =
useState<ManualAddMode>('diffusers');
return (
<Flex flexDirection="column" gap={4} width="100%">
<IAIMantineSelect
label="Model Type"
value={advancedAddMode}
data={advancedAddModeData}
onChange={(v) => {
if (!v) return;
setAdvancedAddMode(v as ManualAddMode);
}}
/>
<Flex
sx={{
p: 4,
borderRadius: 4,
bg: 'base.300',
_dark: {
bg: 'base.850',
},
}}
>
{advancedAddMode === 'diffusers' && <AdvancedAddDiffusers />}
{advancedAddMode === 'checkpoint' && <AdvancedAddCheckpoint />}
</Flex>
</Flex>
);
}

View File

@ -0,0 +1,253 @@
import { Flex, Text } from '@chakra-ui/react';
import { makeToast } from 'app/components/Toaster';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAIScrollArea from 'common/components/IAIScrollArea';
import { addToast } from 'features/system/store/systemSlice';
import { difference, forEach, intersection, map, values } from 'lodash-es';
import { ChangeEvent, MouseEvent, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import {
SearchFolderResponse,
useGetMainModelsQuery,
useGetModelsInFolderQuery,
useImportMainModelsMutation,
} from 'services/api/endpoints/models';
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
export default function FoundModelsList() {
const searchFolder = useAppSelector(
(state: RootState) => state.modelmanager.searchFolder
);
const [nameFilter, setNameFilter] = useState<string>('');
// Get paths of models that are already installed
const { data: installedModels } = useGetMainModelsQuery();
// Get all model paths from a given directory
const { foundModels, alreadyInstalled, filteredModels } =
useGetModelsInFolderQuery(
{
search_path: searchFolder ? searchFolder : '',
},
{
selectFromResult: ({ data }) => {
const installedModelValues = values(installedModels?.entities);
const installedModelPaths = map(installedModelValues, 'path');
// Only select models those that aren't already installed to Invoke
const notInstalledModels = difference(data, installedModelPaths);
const alreadyInstalled = intersection(data, installedModelPaths);
return {
foundModels: data,
alreadyInstalled: foundModelsFilter(alreadyInstalled, nameFilter),
filteredModels: foundModelsFilter(notInstalledModels, nameFilter),
};
},
}
);
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const quickAddHandler = useCallback(
(e: MouseEvent<HTMLButtonElement>) => {
const model_name = e.currentTarget.id.split('\\').splice(-1)[0];
importMainModel({
body: {
location: e.currentTarget.id,
},
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: `Added Model: ${model_name}`,
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: 'Faile To Add Model',
status: 'error',
})
)
);
}
});
},
[dispatch, importMainModel]
);
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setNameFilter(e.target.value);
}, []);
const renderModels = ({
models,
showActions = true,
}: {
models: string[];
showActions?: boolean;
}) => {
return models.map((model) => {
return (
<Flex
sx={{
p: 4,
gap: 4,
alignItems: 'center',
borderRadius: 4,
bg: 'base.200',
_dark: {
bg: 'base.800',
},
}}
key={model}
>
<Flex w="100%" sx={{ flexDirection: 'column', minW: '25%' }}>
<Text sx={{ fontWeight: 600 }}>
{model.split('\\').slice(-1)[0]}
</Text>
<Text
sx={{
fontSize: 'sm',
color: 'base.600',
_dark: {
color: 'base.400',
},
}}
>
{model}
</Text>
</Flex>
{showActions ? (
<Flex gap={2}>
<IAIButton
id={model}
onClick={quickAddHandler}
isLoading={isLoading}
>
Quick Add
</IAIButton>
<IAIButton
onClick={() => dispatch(setAdvancedAddScanModel(model))}
isLoading={isLoading}
>
Advanced
</IAIButton>
</Flex>
) : (
<Text
sx={{
fontWeight: 600,
p: 2,
borderRadius: 4,
color: 'accent.50',
bg: 'accent.400',
_dark: {
color: 'accent.100',
bg: 'accent.600',
},
}}
>
Installed
</Text>
)}
</Flex>
);
});
};
const renderFoundModels = () => {
if (!searchFolder) return;
if (!foundModels || foundModels.length === 0) {
return (
<Flex
sx={{
w: 'full',
h: 'full',
justifyContent: 'center',
alignItems: 'center',
height: 96,
userSelect: 'none',
bg: 'base.200',
_dark: {
bg: 'base.900',
},
}}
>
<Text variant="subtext">No Models Found</Text>
</Flex>
);
}
return (
<Flex
sx={{
flexDirection: 'column',
gap: 2,
w: '100%',
minW: '50%',
}}
>
<IAIInput
onChange={handleSearchFilter}
label={t('modelManager.search')}
labelPos="side"
/>
<Flex p={2} gap={2}>
<Text sx={{ fontWeight: 600 }}>
Models Found: {foundModels.length}
</Text>
<Text
sx={{
fontWeight: 600,
color: 'accent.500',
_dark: {
color: 'accent.200',
},
}}
>
Not Installed: {filteredModels.length}
</Text>
</Flex>
<IAIScrollArea offsetScrollbars>
<Flex gap={2} flexDirection="column">
{renderModels({ models: filteredModels })}
{renderModels({ models: alreadyInstalled, showActions: false })}
</Flex>
</IAIScrollArea>
</Flex>
);
};
return renderFoundModels();
}
const foundModelsFilter = (
data: SearchFolderResponse | undefined,
nameFilter: string
) => {
const filteredModels: SearchFolderResponse = [];
forEach(data, (model) => {
if (!model) {
return;
}
if (model.includes(nameFilter)) {
filteredModels.push(model);
}
});
return filteredModels;
};

View File

@ -0,0 +1,97 @@
import { Box, Flex, Text } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { motion } from 'framer-motion';
import { useEffect, useState } from 'react';
import { FaTimes } from 'react-icons/fa';
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
import AdvancedAddCheckpoint from './AdvancedAddCheckpoint';
import AdvancedAddDiffusers from './AdvancedAddDiffusers';
import { ManualAddMode, advancedAddModeData } from './AdvancedAddModels';
export default function ScanAdvancedAddModels() {
const advancedAddScanModel = useAppSelector(
(state: RootState) => state.modelmanager.advancedAddScanModel
);
const [advancedAddMode, setAdvancedAddMode] =
useState<ManualAddMode>('diffusers');
const [isCheckpoint, setIsCheckpoint] = useState<boolean>(true);
useEffect(() => {
advancedAddScanModel &&
['.ckpt', '.safetensors', '.pth', '.pt'].some((ext) =>
advancedAddScanModel.endsWith(ext)
)
? setAdvancedAddMode('checkpoint')
: setAdvancedAddMode('diffusers');
}, [advancedAddScanModel, setAdvancedAddMode, isCheckpoint]);
const dispatch = useAppDispatch();
return (
advancedAddScanModel && (
<Box
as={motion.div}
initial={{ x: -100, opacity: 0 }}
animate={{ x: 0, opacity: 1, transition: { duration: 0.2 } }}
sx={{
display: 'flex',
flexDirection: 'column',
minWidth: '40%',
maxHeight: window.innerHeight - 300,
overflow: 'scroll',
p: 4,
gap: 4,
borderRadius: 4,
bg: 'base.200',
_dark: {
bg: 'base.800',
},
}}
>
<Flex justifyContent="space-between" alignItems="center">
<Text size="xl" fontWeight={600}>
{isCheckpoint || advancedAddMode === 'checkpoint'
? 'Add Checkpoint Model'
: 'Add Diffusers Model'}
</Text>
<IAIIconButton
icon={<FaTimes />}
aria-label="Close Advanced"
onClick={() => dispatch(setAdvancedAddScanModel(null))}
size="sm"
/>
</Flex>
<IAIMantineSelect
label="Model Type"
value={advancedAddMode}
data={advancedAddModeData}
onChange={(v) => {
if (!v) return;
setAdvancedAddMode(v as ManualAddMode);
if (v === 'checkpoint') {
setIsCheckpoint(true);
} else {
setIsCheckpoint(false);
}
}}
/>
{isCheckpoint ? (
<AdvancedAddCheckpoint
key={advancedAddScanModel}
model_path={advancedAddScanModel}
/>
) : (
<AdvancedAddDiffusers
key={advancedAddScanModel}
model_path={advancedAddScanModel}
/>
)}
</Box>
)
);
}

View File

@ -0,0 +1,25 @@
import { Flex } from '@chakra-ui/react';
import FoundModelsList from './FoundModelsList';
import ScanAdvancedAddModels from './ScanAdvancedAddModels';
import SearchFolderForm from './SearchFolderForm';
export default function ScanModels() {
return (
<Flex flexDirection="column" w="100%" gap={4}>
<SearchFolderForm />
<Flex gap={4}>
<Flex
sx={{
maxHeight: window.innerHeight - 300,
overflow: 'scroll',
gap: 4,
w: '100%',
}}
>
<FoundModelsList />
</Flex>
<ScanAdvancedAddModels />
</Flex>
</Flex>
);
}

View File

@ -0,0 +1,139 @@
import { Flex, Text } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIInput from 'common/components/IAIInput';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { FaSearch, FaSync, FaTrash } from 'react-icons/fa';
import { useGetModelsInFolderQuery } from 'services/api/endpoints/models';
import {
setAdvancedAddScanModel,
setSearchFolder,
} from '../../store/modelManagerSlice';
type SearchFolderForm = {
folder: string;
};
function SearchFolderForm() {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const searchFolder = useAppSelector(
(state: RootState) => state.modelmanager.searchFolder
);
const { refetch: refetchFoundModels } = useGetModelsInFolderQuery({
search_path: searchFolder ? searchFolder : '',
});
const searchFolderForm = useForm<SearchFolderForm>({
initialValues: {
folder: '',
},
});
const searchFolderFormSubmitHandler = useCallback(
(values: SearchFolderForm) => {
dispatch(setSearchFolder(values.folder));
},
[dispatch]
);
const scanAgainHandler = () => {
refetchFoundModels();
};
return (
<form
onSubmit={searchFolderForm.onSubmit((values) =>
searchFolderFormSubmitHandler(values)
)}
style={{ width: '100%' }}
>
<Flex
sx={{
w: '100%',
gap: 2,
borderRadius: 4,
alignItems: 'center',
}}
>
<Flex w="100%" alignItems="center" gap={4} minH={12}>
<Text
sx={{
fontSize: 'sm',
fontWeight: 600,
color: 'base.700',
minW: 'max-content',
_dark: { color: 'base.300' },
}}
>
Folder
</Text>
{!searchFolder ? (
<IAIInput
w="100%"
size="md"
{...searchFolderForm.getInputProps('folder')}
/>
) : (
<Flex
sx={{
w: '100%',
p: 2,
px: 4,
bg: 'base.300',
borderRadius: 4,
fontSize: 'sm',
fontWeight: 'bold',
_dark: { bg: 'base.700' },
}}
>
{searchFolder}
</Flex>
)}
</Flex>
<Flex gap={2}>
{!searchFolder ? (
<IAIIconButton
aria-label={t('modelManager.findModels')}
tooltip={t('modelManager.findModels')}
icon={<FaSearch />}
fontSize={18}
size="sm"
type="submit"
/>
) : (
<IAIIconButton
aria-label={t('modelManager.scanAgain')}
tooltip={t('modelManager.scanAgain')}
icon={<FaSync />}
onClick={scanAgainHandler}
fontSize={18}
size="sm"
/>
)}
<IAIIconButton
aria-label={t('modelManager.clearCheckpointFolder')}
tooltip={t('modelManager.clearCheckpointFolder')}
icon={<FaTrash />}
size="sm"
onClick={() => {
dispatch(setSearchFolder(null));
dispatch(setAdvancedAddScanModel(null));
}}
isDisabled={!searchFolder}
colorScheme="red"
/>
</Flex>
</Flex>
</form>
);
}
export default memo(SearchFolderForm);

View File

@ -0,0 +1,108 @@
import { Flex } from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useTranslation } from 'react-i18next';
import { SelectItem } from '@mantine/core';
import { useForm } from '@mantine/form';
import { makeToast } from 'app/components/Toaster';
import { RootState } from 'app/store/store';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { addToast } from 'features/system/store/systemSlice';
import { useImportMainModelsMutation } from 'services/api/endpoints/models';
const predictionSelectData: SelectItem[] = [
{ label: 'None', value: 'none' },
{ label: 'v_prediction', value: 'v_prediction' },
{ label: 'epsilon', value: 'epsilon' },
{ label: 'sample', value: 'sample' },
];
type ExtendedImportModelConfig = {
location: string;
prediction_type?: 'v_prediction' | 'epsilon' | 'sample' | 'none' | undefined;
};
export default function SimpleAddModels() {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
const addModelForm = useForm<ExtendedImportModelConfig>({
initialValues: {
location: '',
prediction_type: undefined,
},
});
const handleAddModelSubmit = (values: ExtendedImportModelConfig) => {
const importModelResponseBody = {
location: values.location,
prediction_type:
values.prediction_type === 'none' ? undefined : values.prediction_type,
};
importMainModel({ body: importModelResponseBody })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: 'Model Added',
status: 'success',
})
)
);
addModelForm.reset();
})
.catch((error) => {
if (error) {
console.log(error);
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
};
return (
<form
onSubmit={addModelForm.onSubmit((v) => handleAddModelSubmit(v))}
style={{ width: '100%' }}
>
<Flex flexDirection="column" width="100%" gap={4}>
<IAIMantineTextInput
label="Model Location"
placeholder="Provide a path to a local Diffusers model, local checkpoint / safetensors model or a HuggingFace Repo ID"
w="100%"
{...addModelForm.getInputProps('location')}
/>
<IAIMantineSelect
label="Prediction Type (for Stable Diffusion 2.x Models only)"
data={predictionSelectData}
defaultValue="none"
{...addModelForm.getInputProps('prediction_type')}
/>
<IAIButton
type="submit"
isLoading={isLoading}
isDisabled={isLoading || isProcessing}
>
{t('modelManager.addModel')}
</IAIButton>
</Flex>
</form>
);
}

View File

@ -0,0 +1,39 @@
import { ButtonGroup, Flex } from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import { useState } from 'react';
import { useTranslation } from 'react-i18next';
import AddModels from './AddModelsPanel/AddModels';
import ScanModels from './AddModelsPanel/ScanModels';
type AddModelTabs = 'add' | 'scan';
export default function ImportModelsPanel() {
const [addModelTab, setAddModelTab] = useState<AddModelTabs>('add');
const { t } = useTranslation();
return (
<Flex flexDirection="column" gap={4}>
<ButtonGroup isAttached>
<IAIButton
onClick={() => setAddModelTab('add')}
isChecked={addModelTab == 'add'}
size="sm"
width="100%"
>
{t('modelManager.addModel')}
</IAIButton>
<IAIButton
onClick={() => setAddModelTab('scan')}
isChecked={addModelTab == 'scan'}
size="sm"
width="100%"
>
{t('modelManager.scanForModels')}
</IAIButton>
</ButtonGroup>
{addModelTab == 'add' && <AddModels />}
{addModelTab == 'scan' && <ScanModels />}
</Flex>
);
}

View File

@ -1,11 +1,4 @@
import {
Flex,
Radio,
RadioGroup,
Text,
Tooltip,
useColorMode,
} from '@chakra-ui/react';
import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
@ -23,7 +16,6 @@ import {
useMergeMainModelsMutation,
} from 'services/api/endpoints/models';
import { BaseModelType, MergeModelConfig } from 'services/api/types';
import { mode } from 'theme/util/mode';
const baseModelTypeSelectData = [
{ label: 'Stable Diffusion 1', value: 'sd-1' },
@ -38,7 +30,6 @@ type MergeInterpolationMethods =
export default function MergeModelsPanel() {
const { t } = useTranslation();
const { colorMode } = useColorMode();
const dispatch = useAppDispatch();
const { data } = useGetMainModelsQuery();
@ -67,10 +58,10 @@ export default function MergeModelsPanel() {
}, [sd1DiffusersModels, sd2DiffusersModels]);
const [modelOne, setModelOne] = useState<string | null>(
Object.keys(modelsMap[baseModel])[0]
Object.keys(modelsMap[baseModel as keyof typeof modelsMap])[0]
);
const [modelTwo, setModelTwo] = useState<string | null>(
Object.keys(modelsMap[baseModel])[1]
Object.keys(modelsMap[baseModel as keyof typeof modelsMap])[1]
);
const [modelThree, setModelThree] = useState<string | null>(null);
@ -98,9 +89,9 @@ export default function MergeModelsPanel() {
modelsMap[baseModel as keyof typeof modelsMap]
).filter((model) => model !== modelOne && model !== modelThree);
const modelThreeList = Object.keys(modelsMap[baseModel]).filter(
(model) => model !== modelOne && model !== modelTwo
);
const modelThreeList = Object.keys(
modelsMap[baseModel as keyof typeof modelsMap]
).filter((model) => model !== modelOne && model !== modelTwo);
const handleBaseModelChange = (v: string) => {
setBaseModel(v as BaseModelType);
@ -125,9 +116,9 @@ export default function MergeModelsPanel() {
mergedModelName !== '' ? mergedModelName : models_names.join('-'),
alpha: modelMergeAlpha,
interp: modelMergeInterp,
// model_merge_save_path:
// modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
force: modelMergeForce,
merge_dest_directory:
modelMergeSaveLocType === 'root' ? undefined : modelMergeCustomSaveLoc,
};
mergeModels({
@ -227,7 +218,10 @@ export default function MergeModelsPanel() {
padding: 4,
borderRadius: 'base',
gap: 4,
bg: mode('base.100', 'base.800')(colorMode),
bg: 'base.200',
_dark: {
bg: 'base.800',
},
}}
>
<IAISlider
@ -252,7 +246,10 @@ export default function MergeModelsPanel() {
padding: 4,
borderRadius: 'base',
gap: 4,
bg: mode('base.100', 'base.800')(colorMode),
bg: 'base.200',
_dark: {
bg: 'base.800',
},
}}
>
<Text fontWeight={500} fontSize="sm" variant="subtext">
@ -288,13 +285,16 @@ export default function MergeModelsPanel() {
</RadioGroup>
</Flex>
{/* <Flex
<Flex
sx={{
flexDirection: 'column',
padding: 4,
borderRadius: 'base',
gap: 4,
bg: 'base.900',
bg: 'base.200',
_dark: {
bg: 'base.900',
},
}}
>
<Flex columnGap={4}>
@ -324,7 +324,7 @@ export default function MergeModelsPanel() {
onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
/>
)}
</Flex> */}
</Flex>
<IAISimpleCheckbox
label={t('modelManager.ignoreMismatch')}

View File

@ -1,33 +1,26 @@
import { Divider, Flex, Text } from '@chakra-ui/react';
import { Badge, Divider, Flex, Text } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback } from 'react';
import { useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import {
CheckpointModelConfigEntity,
useGetCheckpointConfigsQuery,
useUpdateMainModelsMutation,
} from 'services/api/endpoints/models';
import { CheckpointModelConfig } from 'services/api/types';
import BaseModelSelect from '../shared/BaseModelSelect';
import CheckpointConfigsSelect from '../shared/CheckpointConfigsSelect';
import ModelVariantSelect from '../shared/ModelVariantSelect';
import ModelConvert from './ModelConvert';
const baseModelSelectData = [
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
];
const variantSelectData = [
{ value: 'normal', label: 'Normal' },
{ value: 'inpaint', label: 'Inpaint' },
{ value: 'depth', label: 'Depth' },
];
type CheckpointModelEditProps = {
model: CheckpointModelConfigEntity;
};
@ -38,6 +31,15 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
const { model } = props;
const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation();
const { data: availableCheckpointConfigs } = useGetCheckpointConfigsQuery();
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
useEffect(() => {
if (!availableCheckpointConfigs?.includes(model.config)) {
setUseCustomConfig(true);
}
}, [availableCheckpointConfigs, model.config]);
const dispatch = useAppDispatch();
const { t } = useTranslation();
@ -80,7 +82,7 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
)
);
})
.catch((error) => {
.catch((_) => {
checkpointEditForm.reset();
dispatch(
addToast(
@ -113,7 +115,20 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
{MODEL_TYPE_MAP[model.base_model]} Model
</Text>
</Flex>
<ModelConvert model={model} />
{!['sdxl', 'sdxl-refiner'].includes(model.base_model) ? (
<ModelConvert model={model} />
) : (
<Badge
sx={{
p: 2,
borderRadius: 4,
bg: 'error.200',
_dark: { bg: 'error.400' },
}}
>
Conversion Not Supported
</Badge>
)}
</Flex>
<Divider />
@ -128,21 +143,24 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
)}
>
<Flex flexDirection="column" overflowY="scroll" gap={4}>
<IAIMantineTextInput
label={t('modelManager.name')}
{...checkpointEditForm.getInputProps('model_name')}
/>
<IAIMantineTextInput
label={t('modelManager.description')}
{...checkpointEditForm.getInputProps('description')}
/>
<IAIMantineSelect
label={t('modelManager.baseModel')}
data={baseModelSelectData}
<BaseModelSelect
required
{...checkpointEditForm.getInputProps('base_model')}
/>
<IAIMantineSelect
label={t('modelManager.variant')}
data={variantSelectData}
<ModelVariantSelect
required
{...checkpointEditForm.getInputProps('variant')}
/>
<IAIMantineTextInput
required
label={t('modelManager.modelLocation')}
{...checkpointEditForm.getInputProps('path')}
/>
@ -150,10 +168,27 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
label={t('modelManager.vaeLocation')}
{...checkpointEditForm.getInputProps('vae')}
/>
<IAIMantineTextInput
label={t('modelManager.config')}
{...checkpointEditForm.getInputProps('config')}
/>
<Flex flexDirection="column" gap={2}>
{!useCustomConfig ? (
<CheckpointConfigsSelect
required
{...checkpointEditForm.getInputProps('config')}
/>
) : (
<IAIMantineTextInput
required
label={t('modelManager.config')}
{...checkpointEditForm.getInputProps('config')}
/>
)}
<IAISimpleCheckbox
isChecked={useCustomConfig}
onChange={() => setUseCustomConfig(!useCustomConfig)}
label="Use Custom Config"
/>
</Flex>
<IAIButton
type="submit"
isDisabled={isBusy || isLoading}

View File

@ -4,7 +4,6 @@ import { makeToast } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
@ -15,22 +14,13 @@ import {
useUpdateMainModelsMutation,
} from 'services/api/endpoints/models';
import { DiffusersModelConfig } from 'services/api/types';
import BaseModelSelect from '../shared/BaseModelSelect';
import ModelVariantSelect from '../shared/ModelVariantSelect';
type DiffusersModelEditProps = {
model: DiffusersModelConfigEntity;
};
const baseModelSelectData = [
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
];
const variantSelectData = [
{ value: 'normal', label: 'Normal' },
{ value: 'inpaint', label: 'Inpaint' },
{ value: 'depth', label: 'Depth' },
];
export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
const isBusy = useAppSelector(selectIsBusy);
@ -65,6 +55,7 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
model_name: model.model_name,
body: values,
};
updateMainModel(responseBody)
.unwrap()
.then((payload) => {
@ -78,7 +69,7 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
)
);
})
.catch((error) => {
.catch((_) => {
diffusersEditForm.reset();
dispatch(
addToast(
@ -118,21 +109,24 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
)}
>
<Flex flexDirection="column" overflowY="scroll" gap={4}>
<IAIMantineTextInput
label={t('modelManager.name')}
{...diffusersEditForm.getInputProps('model_name')}
/>
<IAIMantineTextInput
label={t('modelManager.description')}
{...diffusersEditForm.getInputProps('description')}
/>
<IAIMantineSelect
label={t('modelManager.baseModel')}
data={baseModelSelectData}
<BaseModelSelect
required
{...diffusersEditForm.getInputProps('base_model')}
/>
<IAIMantineSelect
label={t('modelManager.variant')}
data={variantSelectData}
<ModelVariantSelect
required
{...diffusersEditForm.getInputProps('variant')}
/>
<IAIMantineTextInput
required
label={t('modelManager.modelLocation')}
{...diffusersEditForm.getInputProps('path')}
/>

View File

@ -1,9 +1,18 @@
import { Flex, ListItem, Text, UnorderedList } from '@chakra-ui/react';
// import { convertToDiffusers } from 'app/socketio/actions';
import {
Flex,
ListItem,
Radio,
RadioGroup,
Text,
Tooltip,
UnorderedList,
} from '@chakra-ui/react';
import { makeToast } from 'app/components/Toaster';
// import { convertToDiffusers } from 'app/socketio/actions';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import { addToast } from 'features/system/store/systemSlice';
import { useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
@ -15,6 +24,8 @@ interface ModelConvertProps {
model: CheckpointModelConfig;
}
type SaveLocation = 'InvokeAIRoot' | 'Custom';
export default function ModelConvert(props: ModelConvertProps) {
const { model } = props;
@ -23,22 +34,51 @@ export default function ModelConvert(props: ModelConvertProps) {
const [convertModel, { isLoading }] = useConvertMainModelsMutation();
const [saveLocation, setSaveLocation] = useState<string>('same');
const [saveLocation, setSaveLocation] =
useState<SaveLocation>('InvokeAIRoot');
const [customSaveLocation, setCustomSaveLocation] = useState<string>('');
useEffect(() => {
setSaveLocation('same');
setSaveLocation('InvokeAIRoot');
}, [model]);
const modelConvertCancelHandler = () => {
setSaveLocation('same');
setSaveLocation('InvokeAIRoot');
};
const modelConvertHandler = () => {
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
params: {
convert_dest_directory:
saveLocation === 'Custom' ? customSaveLocation : undefined,
},
};
if (saveLocation === 'Custom' && customSaveLocation === '') {
dispatch(
addToast(
makeToast({
title: t('modelManager.noCustomLocationProvided'),
status: 'error',
})
)
);
return;
}
dispatch(
addToast(
makeToast({
title: `${t('modelManager.convertingModelBegin')}: ${
model.model_name
}`,
status: 'success',
})
)
);
convertModel(responseBody)
.unwrap()
.then((_) => {
@ -94,35 +134,30 @@ export default function ModelConvert(props: ModelConvertProps) {
<Text>{t('modelManager.convertToDiffusersHelpText6')}</Text>
</Flex>
{/* <Flex flexDir="column" gap={4}>
<Flex flexDir="column" gap={2}>
<Flex marginTop={4} flexDir="column" gap={2}>
<Text fontWeight="600">
{t('modelManager.convertToDiffusersSaveLocation')}
</Text>
<RadioGroup value={saveLocation} onChange={(v) => setSaveLocation(v)}>
<RadioGroup
value={saveLocation}
onChange={(v) => setSaveLocation(v as SaveLocation)}
>
<Flex gap={4}>
<Radio value="same">
<Tooltip label="Save converted model in the same folder">
{t('modelManager.sameFolder')}
</Tooltip>
</Radio>
<Radio value="root">
<Radio value="InvokeAIRoot">
<Tooltip label="Save converted model in the InvokeAI root folder">
{t('modelManager.invokeRoot')}
</Tooltip>
</Radio>
<Radio value="custom">
<Radio value="Custom">
<Tooltip label="Save converted model in a custom folder">
{t('modelManager.custom')}
</Tooltip>
</Radio>
</Flex>
</RadioGroup>
</Flex> */}
{/* {saveLocation === 'custom' && (
</Flex>
{saveLocation === 'Custom' && (
<Flex flexDirection="column" rowGap={2}>
<Text fontWeight="500" fontSize="sm" variant="subtext">
{t('modelManager.customSaveLocation')}
@ -130,13 +165,13 @@ export default function ModelConvert(props: ModelConvertProps) {
<IAIInput
value={customSaveLocation}
onChange={(e) => {
if (e.target.value !== '')
setCustomSaveLocation(e.target.value);
setCustomSaveLocation(e.target.value);
}}
width="full"
/>
</Flex>
)} */}
)}
</Flex>
</IAIAlertDialog>
);
}

View File

@ -44,10 +44,6 @@ const ModelList = (props: ModelListProps) => {
return (
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
<IAIInput
onChange={handleSearchFilter}
label={t('modelManager.search')}
/>
<Flex
flexDirection="column"
gap={4}
@ -79,6 +75,12 @@ const ModelList = (props: ModelListProps) => {
</IAIButton>
</ButtonGroup>
<IAIInput
onChange={handleSearchFilter}
label={t('modelManager.search')}
labelPos="side"
/>
{['all', 'diffusers'].includes(modelFormatFilter) &&
filteredDiffusersModels.length > 0 && (
<Flex sx={{ gap: 2, flexDir: 'column' }}>

View File

@ -1,10 +1,12 @@
import { DeleteIcon } from '@chakra-ui/icons';
import { Flex, Text, Tooltip } from '@chakra-ui/react';
import { useAppSelector } from 'app/store/storeHooks';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
@ -21,6 +23,7 @@ type ModelListItemProps = {
export default function ModelListItem(props: ModelListItemProps) {
const isBusy = useAppSelector(selectIsBusy);
const { t } = useTranslation();
const dispatch = useAppDispatch();
const [deleteMainModel] = useDeleteMainModelsMutation();
const { model, isSelected, setSelectedModelId } = props;
@ -30,9 +33,34 @@ export default function ModelListItem(props: ModelListItemProps) {
}, [model.id, setSelectedModelId]);
const handleModelDelete = useCallback(() => {
deleteMainModel(model);
deleteMainModel(model)
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelDeleted')}: ${model.model_name}`,
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelDeleteFailed')}: ${
model.model_name
}`,
status: 'success',
})
)
);
}
});
setSelectedModelId(undefined);
}, [deleteMainModel, model, setSelectedModelId]);
}, [deleteMainModel, model, setSelectedModelId, dispatch, t]);
return (
<Flex sx={{ gap: 2, alignItems: 'center', w: 'full' }}>

View File

@ -0,0 +1,28 @@
import IAIMantineSelect, {
IAISelectDataType,
IAISelectProps,
} from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { useTranslation } from 'react-i18next';
const baseModelSelectData: IAISelectDataType[] = [
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
{ value: 'sdxl', label: MODEL_TYPE_MAP['sdxl'] },
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
];
type BaseModelSelectProps = Omit<IAISelectProps, 'data'>;
export default function BaseModelSelect(props: BaseModelSelectProps) {
const { ...rest } = props;
const { t } = useTranslation();
return (
<IAIMantineSelect
label={t('modelManager.baseModel')}
data={baseModelSelectData}
{...rest}
/>
);
}

View File

@ -0,0 +1,22 @@
import IAIMantineSelect, {
IAISelectProps,
} from 'common/components/IAIMantineSelect';
import { useGetCheckpointConfigsQuery } from 'services/api/endpoints/models';
type CheckpointConfigSelectProps = Omit<IAISelectProps, 'data'>;
export default function CheckpointConfigsSelect(
props: CheckpointConfigSelectProps
) {
const { data: availableCheckpointConfigs } = useGetCheckpointConfigsQuery();
const { ...rest } = props;
return (
<IAIMantineSelect
label="Config File"
placeholder="Select A Config File"
data={availableCheckpointConfigs ? availableCheckpointConfigs : []}
{...rest}
/>
);
}

View File

@ -0,0 +1,26 @@
import IAIMantineSelect, {
IAISelectDataType,
IAISelectProps,
} from 'common/components/IAIMantineSelect';
import { useTranslation } from 'react-i18next';
const variantSelectData: IAISelectDataType[] = [
{ value: 'normal', label: 'Normal' },
{ value: 'inpaint', label: 'Inpaint' },
{ value: 'depth', label: 'Depth' },
];
type VariantSelectProps = Omit<IAISelectProps, 'data'>;
export default function ModelVariantSelect(props: VariantSelectProps) {
const { ...rest } = props;
const { t } = useTranslation();
return (
<IAIMantineSelect
label={t('modelManager.variant')}
data={variantSelectData}
{...rest}
/>
);
}

View File

@ -4,6 +4,8 @@ import IAIIconButton from 'common/components/IAIIconButton';
import { canvasCopiedToClipboard } from 'features/canvas/store/actions';
import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
import { useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaCopy } from 'react-icons/fa';
@ -11,6 +13,7 @@ import { FaCopy } from 'react-icons/fa';
export default function UnifiedCanvasCopyToClipboard() {
const isStaging = useAppSelector(isStagingSelector);
const canvasBaseLayer = getCanvasBaseLayer();
const { isClipboardAPIAvailable } = useCopyImageToClipboard();
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
@ -25,15 +28,22 @@ export default function UnifiedCanvasCopyToClipboard() {
handleCopyImageToClipboard();
},
{
enabled: () => !isStaging,
enabled: () => !isStaging && isClipboardAPIAvailable,
preventDefault: true,
},
[canvasBaseLayer, isProcessing]
[canvasBaseLayer, isProcessing, isClipboardAPIAvailable]
);
const handleCopyImageToClipboard = () => {
const handleCopyImageToClipboard = useCallback(() => {
if (!isClipboardAPIAvailable) {
return;
}
dispatch(canvasCopiedToClipboard());
};
}, [dispatch, isClipboardAPIAvailable]);
if (!isClipboardAPIAvailable) {
return null;
}
return (
<IAIIconButton

View File

@ -0,0 +1,52 @@
import { useAppToaster } from 'app/components/Toaster';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
export const useCopyImageToClipboard = () => {
const toaster = useAppToaster();
const { t } = useTranslation();
const isClipboardAPIAvailable = useMemo(() => {
return Boolean(navigator.clipboard) && Boolean(window.ClipboardItem);
}, []);
const copyImageToClipboard = useCallback(
async (image_url: string) => {
if (!isClipboardAPIAvailable) {
toaster({
title: t('toast.problemCopyingImage'),
description: "Your browser doesn't support the Clipboard API.",
status: 'error',
duration: 2500,
isClosable: true,
});
}
try {
const response = await fetch(image_url);
const blob = await response.blob();
await navigator.clipboard.write([
new ClipboardItem({
[blob.type]: blob,
}),
]);
toaster({
title: t('toast.imageCopied'),
status: 'success',
duration: 2500,
isClosable: true,
});
} catch (err) {
toaster({
title: t('toast.problemCopyingImage'),
description: String(err),
status: 'error',
duration: 2500,
isClosable: true,
});
}
},
[isClipboardAPIAvailable, t, toaster]
);
return { isClipboardAPIAvailable, copyImageToClipboard };
};

View File

@ -1,7 +1,5 @@
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
export type AddNewModelType = 'ckpt' | 'diffusers' | null;
export type Coordinates = {
x: number;
y: number;
@ -22,7 +20,6 @@ export interface UIState {
shouldUseCanvasBetaLayout: boolean;
shouldShowExistingModelsInSearch: boolean;
shouldUseSliders: boolean;
addNewModelUIOption: AddNewModelType;
shouldHidePreview: boolean;
shouldPinGallery: boolean;
shouldShowGallery: boolean;

View File

@ -5,7 +5,9 @@ import {
BaseModelType,
CheckpointModelConfig,
ControlNetModelConfig,
ConvertModelConfig,
DiffusersModelConfig,
ImportModelConfig,
LoRAModelConfig,
MainModelConfig,
MergeModelConfig,
@ -13,8 +15,9 @@ import {
VaeModelConfig,
} from 'services/api/types';
import queryString from 'query-string';
import { ApiFullTagDescription, LIST_TAG, api } from '..';
import { paths } from '../schema';
import { operations, paths } from '../schema';
export type DiffusersModelConfigEntity = DiffusersModelConfig & { id: string };
export type CheckpointModelConfigEntity = CheckpointModelConfig & {
@ -62,6 +65,7 @@ type DeleteMainModelResponse = void;
type ConvertMainModelArg = {
base_model: BaseModelType;
model_name: string;
params: ConvertModelConfig;
};
type ConvertMainModelResponse =
@ -75,6 +79,28 @@ type MergeMainModelArg = {
type MergeMainModelResponse =
paths['/api/v1/models/merge/{base_model}']['put']['responses']['200']['content']['application/json'];
type ImportMainModelArg = {
body: ImportModelConfig;
};
type ImportMainModelResponse =
paths['/api/v1/models/import']['post']['responses']['201']['content']['application/json'];
type AddMainModelArg = {
body: MainModelConfig;
};
type AddMainModelResponse =
paths['/api/v1/models/add']['post']['responses']['201']['content']['application/json'];
export type SearchFolderResponse =
paths['/api/v1/models/search']['get']['responses']['200']['content']['application/json'];
type CheckpointConfigsResponse =
paths['/api/v1/models/ckpt_confs']['get']['responses']['200']['content']['application/json'];
type SearchFolderArg = operations['search_for_models']['parameters']['query'];
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
@ -160,6 +186,29 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
importMainModels: build.mutation<
ImportMainModelResponse,
ImportMainModelArg
>({
query: ({ body }) => {
return {
url: `models/import`,
method: 'POST',
body: body,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
query: ({ body }) => {
return {
url: `models/add`,
method: 'POST',
body: body,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
deleteMainModels: build.mutation<
DeleteMainModelResponse,
DeleteMainModelArg
@ -176,10 +225,11 @@ export const modelsApi = api.injectEndpoints({
ConvertMainModelResponse,
ConvertMainModelArg
>({
query: ({ base_model, model_name }) => {
query: ({ base_model, model_name, params }) => {
return {
url: `models/convert/${base_model}/main/${model_name}`,
method: 'PUT',
params: params,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
@ -328,6 +378,36 @@ export const modelsApi = api.injectEndpoints({
);
},
}),
getModelsInFolder: build.query<SearchFolderResponse, SearchFolderArg>({
query: (arg) => {
const folderQueryStr = queryString.stringify(arg, {});
return {
url: `/models/search?${folderQueryStr}`,
};
},
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ type: 'ScannedModels', id: LIST_TAG },
];
if (result) {
tags.push(
...result.map((id) => ({
type: 'ScannedModels' as const,
id,
}))
);
}
return tags;
},
}),
getCheckpointConfigs: build.query<CheckpointConfigsResponse, void>({
query: () => {
return {
url: `/models/ckpt_confs`,
};
},
}),
}),
});
@ -339,6 +419,10 @@ export const {
useGetVaeModelsQuery,
useUpdateMainModelsMutation,
useDeleteMainModelsMutation,
useImportMainModelsMutation,
useAddMainModelsMutation,
useConvertMainModelsMutation,
useMergeMainModelsMutation,
useGetModelsInFolderQuery,
useGetCheckpointConfigsQuery,
} = modelsApi;

View File

@ -84,7 +84,7 @@ export type paths = {
delete: operations["del_model"];
/**
* Update Model
* @description Add Model
* @description Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.
*/
patch: operations["update_model"];
};
@ -102,13 +102,6 @@ export type paths = {
*/
post: operations["add_model"];
};
"/api/v1/models/rename/{base_model}/{model_type}/{model_name}": {
/**
* Rename Model
* @description Rename a model
*/
post: operations["rename_model"];
};
"/api/v1/models/convert/{base_model}/{model_type}/{model_name}": {
/**
* Convert Model
@ -323,7 +316,7 @@ export type components = {
* @description An enumeration.
* @enum {string}
*/
BaseModelType: "sd-1" | "sd-2";
BaseModelType: "sd-1" | "sd-2" | "sdxl" | "sdxl-refiner";
/** BoardChanges */
BoardChanges: {
/**
@ -1226,7 +1219,7 @@ export type components = {
* @description The nodes in this graph
*/
nodes?: {
[key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RealESRGANInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
[key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["RealESRGANInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
};
/**
* Edges
@ -1269,7 +1262,7 @@ export type components = {
* @description The results of node executions
*/
results: {
[key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["MetadataAccumulatorOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
[key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["MetadataAccumulatorOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
};
/**
* Errors
@ -2067,6 +2060,12 @@ export type components = {
* @default false
*/
tiled?: boolean;
/**
* Fp32
* @description Decode in full precision
* @default false
*/
fp32?: boolean;
};
/**
* ImageUrlsDTO
@ -2522,6 +2521,12 @@ export type components = {
* @default false
*/
tiled?: boolean;
/**
* Fp32
* @description Decode in full precision
* @default false
*/
fp32?: boolean;
/**
* Metadata
* @description Optional core metadata to be written to the image
@ -3342,7 +3347,7 @@ export type components = {
/** ModelsList */
ModelsList: {
/** Models */
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[];
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"])[];
};
/**
* MultiplyInvocation
@ -3770,6 +3775,56 @@ export type components = {
*/
prompt: string;
};
/**
* PromptsFromFileInvocation
* @description Loads prompts from a text file
*/
PromptsFromFileInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default prompt_from_file
* @enum {string}
*/
type?: "prompt_from_file";
/**
* File Path
* @description Path to prompt text file
*/
file_path: string;
/**
* Pre Prompt
* @description String to prepend to each prompt
*/
pre_prompt?: string;
/**
* Post Prompt
* @description String to append to each prompt
*/
post_prompt?: string;
/**
* Start Line
* @description Line in the file to start start from
* @default 1
*/
start_line?: number;
/**
* Max Prompts
* @description Max lines to read from file (0=all)
* @default 1
*/
max_prompts?: number;
};
/**
* RandomIntInvocation
* @description Outputs a single random integer.
@ -4032,10 +4087,10 @@ export type components = {
*/
ResourceOrigin: "internal" | "external";
/**
* RestoreFaceInvocation
* @description Restores faces in an image.
* SDXLCompelPromptInvocation
* @description Parse prompt using compel package to conditioning.
*/
RestoreFaceInvocation: {
SDXLCompelPromptInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
@ -4049,21 +4104,514 @@ export type components = {
is_intermediate?: boolean;
/**
* Type
* @default restore_face
* @default sdxl_compel_prompt
* @enum {string}
*/
type?: "restore_face";
type?: "sdxl_compel_prompt";
/**
* Image
* @description The input image
* Prompt
* @description Prompt
* @default
*/
image?: components["schemas"]["ImageField"];
prompt?: string;
/**
* Strength
* @description The strength of the restoration
* @default 0.75
* Style
* @description Style prompt
* @default
*/
strength?: number;
style?: string;
/**
* Original Width
* @default 1024
*/
original_width?: number;
/**
* Original Height
* @default 1024
*/
original_height?: number;
/**
* Crop Top
* @default 0
*/
crop_top?: number;
/**
* Crop Left
* @default 0
*/
crop_left?: number;
/**
* Target Width
* @default 1024
*/
target_width?: number;
/**
* Target Height
* @default 1024
*/
target_height?: number;
/**
* Clip1
* @description Clip to use
*/
clip1?: components["schemas"]["ClipField"];
/**
* Clip2
* @description Clip to use
*/
clip2?: components["schemas"]["ClipField"];
};
/**
* SDXLLatentsToLatentsInvocation
* @description Generates latents from conditionings.
*/
SDXLLatentsToLatentsInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default l2l_sdxl
* @enum {string}
*/
type?: "l2l_sdxl";
/**
* Positive Conditioning
* @description Positive conditioning for generation
*/
positive_conditioning?: components["schemas"]["ConditioningField"];
/**
* Negative Conditioning
* @description Negative conditioning for generation
*/
negative_conditioning?: components["schemas"]["ConditioningField"];
/**
* Noise
* @description The noise to use
*/
noise?: components["schemas"]["LatentsField"];
/**
* Steps
* @description The number of steps to use to generate the image
* @default 10
*/
steps?: number;
/**
* Cfg Scale
* @description The Classifier-Free Guidance, higher values may result in a result closer to the prompt
* @default 7.5
*/
cfg_scale?: number | (number)[];
/**
* Scheduler
* @description The scheduler to use
* @default euler
* @enum {string}
*/
scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc";
/**
* Unet
* @description UNet submodel
*/
unet?: components["schemas"]["UNetField"];
/**
* Latents
* @description Initial latents
*/
latents?: components["schemas"]["LatentsField"];
/**
* Denoising Start
* @default 0
*/
denoising_start?: number;
/**
* Denoising End
* @default 1
*/
denoising_end?: number;
};
/**
* SDXLModelLoaderInvocation
* @description Loads an sdxl base model, outputting its submodels.
*/
SDXLModelLoaderInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default sdxl_model_loader
* @enum {string}
*/
type?: "sdxl_model_loader";
/**
* Model
* @description The model to load
*/
model: components["schemas"]["MainModelField"];
};
/**
* SDXLModelLoaderOutput
* @description SDXL base model loader output
*/
SDXLModelLoaderOutput: {
/**
* Type
* @default sdxl_model_loader_output
* @enum {string}
*/
type?: "sdxl_model_loader_output";
/**
* Unet
* @description UNet submodel
*/
unet?: components["schemas"]["UNetField"];
/**
* Clip
* @description Tokenizer and text_encoder submodels
*/
clip?: components["schemas"]["ClipField"];
/**
* Clip2
* @description Tokenizer and text_encoder submodels
*/
clip2?: components["schemas"]["ClipField"];
/**
* Vae
* @description Vae submodel
*/
vae?: components["schemas"]["VaeField"];
};
/**
* SDXLRawPromptInvocation
* @description Parse prompt using compel package to conditioning.
*/
SDXLRawPromptInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default sdxl_raw_prompt
* @enum {string}
*/
type?: "sdxl_raw_prompt";
/**
* Prompt
* @description Prompt
* @default
*/
prompt?: string;
/**
* Style
* @description Style prompt
* @default
*/
style?: string;
/**
* Original Width
* @default 1024
*/
original_width?: number;
/**
* Original Height
* @default 1024
*/
original_height?: number;
/**
* Crop Top
* @default 0
*/
crop_top?: number;
/**
* Crop Left
* @default 0
*/
crop_left?: number;
/**
* Target Width
* @default 1024
*/
target_width?: number;
/**
* Target Height
* @default 1024
*/
target_height?: number;
/**
* Clip1
* @description Clip to use
*/
clip1?: components["schemas"]["ClipField"];
/**
* Clip2
* @description Clip to use
*/
clip2?: components["schemas"]["ClipField"];
};
/**
* SDXLRefinerCompelPromptInvocation
* @description Parse prompt using compel package to conditioning.
*/
SDXLRefinerCompelPromptInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default sdxl_refiner_compel_prompt
* @enum {string}
*/
type?: "sdxl_refiner_compel_prompt";
/**
* Style
* @description Style prompt
* @default
*/
style?: string;
/**
* Original Width
* @default 1024
*/
original_width?: number;
/**
* Original Height
* @default 1024
*/
original_height?: number;
/**
* Crop Top
* @default 0
*/
crop_top?: number;
/**
* Crop Left
* @default 0
*/
crop_left?: number;
/**
* Aesthetic Score
* @default 6
*/
aesthetic_score?: number;
/**
* Clip2
* @description Clip to use
*/
clip2?: components["schemas"]["ClipField"];
};
/**
* SDXLRefinerModelLoaderInvocation
* @description Loads an sdxl refiner model, outputting its submodels.
*/
SDXLRefinerModelLoaderInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default sdxl_refiner_model_loader
* @enum {string}
*/
type?: "sdxl_refiner_model_loader";
/**
* Model
* @description The model to load
*/
model: components["schemas"]["MainModelField"];
};
/**
* SDXLRefinerModelLoaderOutput
* @description SDXL refiner model loader output
*/
SDXLRefinerModelLoaderOutput: {
/**
* Type
* @default sdxl_refiner_model_loader_output
* @enum {string}
*/
type?: "sdxl_refiner_model_loader_output";
/**
* Unet
* @description UNet submodel
*/
unet?: components["schemas"]["UNetField"];
/**
* Clip2
* @description Tokenizer and text_encoder submodels
*/
clip2?: components["schemas"]["ClipField"];
/**
* Vae
* @description Vae submodel
*/
vae?: components["schemas"]["VaeField"];
};
/**
* SDXLRefinerRawPromptInvocation
* @description Parse prompt using compel package to conditioning.
*/
SDXLRefinerRawPromptInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default sdxl_refiner_raw_prompt
* @enum {string}
*/
type?: "sdxl_refiner_raw_prompt";
/**
* Style
* @description Style prompt
* @default
*/
style?: string;
/**
* Original Width
* @default 1024
*/
original_width?: number;
/**
* Original Height
* @default 1024
*/
original_height?: number;
/**
* Crop Top
* @default 0
*/
crop_top?: number;
/**
* Crop Left
* @default 0
*/
crop_left?: number;
/**
* Aesthetic Score
* @default 6
*/
aesthetic_score?: number;
/**
* Clip2
* @description Clip to use
*/
clip2?: components["schemas"]["ClipField"];
};
/**
* SDXLTextToLatentsInvocation
* @description Generates latents from conditionings.
*/
SDXLTextToLatentsInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default t2l_sdxl
* @enum {string}
*/
type?: "t2l_sdxl";
/**
* Positive Conditioning
* @description Positive conditioning for generation
*/
positive_conditioning?: components["schemas"]["ConditioningField"];
/**
* Negative Conditioning
* @description Negative conditioning for generation
*/
negative_conditioning?: components["schemas"]["ConditioningField"];
/**
* Noise
* @description The noise to use
*/
noise?: components["schemas"]["LatentsField"];
/**
* Steps
* @description The number of steps to use to generate the image
* @default 10
*/
steps?: number;
/**
* Cfg Scale
* @description The Classifier-Free Guidance, higher values may result in a result closer to the prompt
* @default 7.5
*/
cfg_scale?: number | (number)[];
/**
* Scheduler
* @description The scheduler to use
* @default euler
* @enum {string}
*/
scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc";
/**
* Unet
* @description UNet submodel
*/
unet?: components["schemas"]["UNetField"];
/**
* Denoising End
* @default 1
*/
denoising_end?: number;
};
/**
* ScaleLatentsInvocation
@ -4267,6 +4815,56 @@ export type components = {
vae?: string;
variant: components["schemas"]["ModelVariantType"];
};
/** StableDiffusionXLModelCheckpointConfig */
StableDiffusionXLModelCheckpointConfig: {
/** Model Name */
model_name: string;
base_model: components["schemas"]["BaseModelType"];
/**
* Model Type
* @enum {string}
*/
model_type: "main";
/** Path */
path: string;
/** Description */
description?: string;
/**
* Model Format
* @enum {string}
*/
model_format: "checkpoint";
error?: components["schemas"]["ModelError"];
/** Vae */
vae?: string;
/** Config */
config: string;
variant: components["schemas"]["ModelVariantType"];
};
/** StableDiffusionXLModelDiffusersConfig */
StableDiffusionXLModelDiffusersConfig: {
/** Model Name */
model_name: string;
base_model: components["schemas"]["BaseModelType"];
/**
* Model Type
* @enum {string}
*/
model_type: "main";
/** Path */
path: string;
/** Description */
description?: string;
/**
* Model Format
* @enum {string}
*/
model_format: "diffusers";
error?: components["schemas"]["ModelError"];
/** Vae */
vae?: string;
variant: components["schemas"]["ModelVariantType"];
};
/**
* StepParamEasingInvocation
* @description Experimental per-step parameter easing for denoising steps
@ -4354,7 +4952,7 @@ export type components = {
* @description An enumeration.
* @enum {string}
*/
SubModelType: "unet" | "text_encoder" | "tokenizer" | "vae" | "scheduler" | "safety_checker";
SubModelType: "unet" | "text_encoder" | "text_encoder_2" | "tokenizer" | "tokenizer_2" | "vae" | "scheduler" | "safety_checker";
/**
* SubtractInvocation
* @description Subtracts two numbers
@ -4659,6 +5257,12 @@ export type components = {
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionXLModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
@ -4775,7 +5379,7 @@ export type operations = {
};
requestBody: {
content: {
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RealESRGANInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["RealESRGANInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
};
};
responses: {
@ -4812,7 +5416,7 @@ export type operations = {
};
requestBody: {
content: {
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RealESRGANInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["RealESRGANInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
};
};
responses: {
@ -5010,8 +5614,8 @@ export type operations = {
list_models: {
parameters: {
query?: {
/** @description Base model */
base_model?: components["schemas"]["BaseModelType"];
/** @description Base models to include */
base_models?: (components["schemas"]["BaseModelType"])[];
/** @description The type of model to get */
model_type?: components["schemas"]["ModelType"];
};
@ -5061,7 +5665,7 @@ export type operations = {
};
/**
* Update Model
* @description Add Model
* @description Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.
*/
update_model: {
parameters: {
@ -5076,20 +5680,22 @@ export type operations = {
};
requestBody: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
};
};
responses: {
/** @description The model was updated successfully */
200: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
};
};
/** @description Bad request */
400: never;
/** @description The model could not be found */
404: never;
/** @description There is already a model corresponding to the new name */
409: never;
/** @description Validation Error */
422: {
content: {
@ -5112,13 +5718,15 @@ export type operations = {
/** @description The model imported successfully */
201: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
};
};
/** @description The model could not be found */
404: never;
/** @description There is already a model corresponding to this path or repo_id */
409: never;
/** @description Unrecognized file/folder format */
415: never;
/** @description Validation Error */
422: {
content: {
@ -5136,14 +5744,14 @@ export type operations = {
add_model: {
requestBody: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
};
};
responses: {
/** @description The model added successfully */
201: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
};
};
/** @description The model could not be found */
@ -5160,46 +5768,6 @@ export type operations = {
424: never;
};
};
/**
* Rename Model
* @description Rename a model
*/
rename_model: {
parameters: {
query?: {
/** @description new model name */
new_name?: string;
/** @description new model base */
new_base?: components["schemas"]["BaseModelType"];
};
path: {
/** @description Base model */
base_model: components["schemas"]["BaseModelType"];
/** @description The type of model */
model_type: components["schemas"]["ModelType"];
/** @description current model name */
model_name: string;
};
};
responses: {
/** @description The model was renamed successfully */
201: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
};
};
/** @description The model could not be found */
404: never;
/** @description There is already a model corresponding to the new name */
409: never;
/** @description Validation Error */
422: {
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
/**
* Convert Model
* @description Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none.
@ -5223,7 +5791,7 @@ export type operations = {
/** @description Model converted successfully */
200: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
};
};
/** @description Bad request */
@ -5312,7 +5880,7 @@ export type operations = {
/** @description Model converted successfully */
200: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
};
};
/** @description Incompatible models */

View File

@ -28,6 +28,7 @@ export type OffsetPaginatedResults_ImageDTO_ =
// Models
export type ModelType = components['schemas']['ModelType'];
export type SubModelType = components['schemas']['SubModelType'];
export type BaseModelType = components['schemas']['BaseModelType'];
export type MainModelField = components['schemas']['MainModelField'];
export type VAEModelField = components['schemas']['VAEModelField'];
@ -46,10 +47,12 @@ export type TextualInversionModelConfig =
components['schemas']['TextualInversionModelConfig'];
export type DiffusersModelConfig =
| components['schemas']['StableDiffusion1ModelDiffusersConfig']
| components['schemas']['StableDiffusion2ModelDiffusersConfig'];
| components['schemas']['StableDiffusion2ModelDiffusersConfig']
| components['schemas']['StableDiffusionXLModelDiffusersConfig'];
export type CheckpointModelConfig =
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
| components['schemas']['StableDiffusion2ModelCheckpointConfig'];
| components['schemas']['StableDiffusion2ModelCheckpointConfig']
| components['schemas']['StableDiffusionXLModelCheckpointConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type AnyModelConfig =
| LoRAModelConfig
@ -57,7 +60,10 @@ export type AnyModelConfig =
| ControlNetModelConfig
| TextualInversionModelConfig
| MainModelConfig;
export type MergeModelConfig = components['schemas']['Body_merge_models'];
export type ConvertModelConfig = components['schemas']['Body_convert_model'];
export type ImportModelConfig = components['schemas']['Body_import_model'];
// Graphs
export type Graph = components['schemas']['Graph'];

View File

@ -5,6 +5,8 @@ import {
InvocationCompleteEvent,
InvocationErrorEvent,
InvocationStartedEvent,
ModelLoadCompletedEvent,
ModelLoadStartedEvent,
} from 'services/events/types';
// Common socket action payload data
@ -162,3 +164,35 @@ export const socketGeneratorProgress = createAction<
export const appSocketGeneratorProgress = createAction<
BaseSocketPayload & { data: GeneratorProgressEvent }
>('socket/appSocketGeneratorProgress');
/**
* Socket.IO Model Load Started
*
* Do not use. Only for use in middleware.
*/
export const socketModelLoadStarted = createAction<
BaseSocketPayload & { data: ModelLoadStartedEvent }
>('socket/socketModelLoadStarted');
/**
* App-level Model Load Started
*/
export const appSocketModelLoadStarted = createAction<
BaseSocketPayload & { data: ModelLoadStartedEvent }
>('socket/appSocketModelLoadStarted');
/**
* Socket.IO Model Load Started
*
* Do not use. Only for use in middleware.
*/
export const socketModelLoadCompleted = createAction<
BaseSocketPayload & { data: ModelLoadCompletedEvent }
>('socket/socketModelLoadCompleted');
/**
* App-level Model Load Completed
*/
export const appSocketModelLoadCompleted = createAction<
BaseSocketPayload & { data: ModelLoadCompletedEvent }
>('socket/appSocketModelLoadCompleted');

View File

@ -1,5 +1,11 @@
import { O } from 'ts-toolbelt';
import { Graph, GraphExecutionState } from '../api/types';
import {
BaseModelType,
Graph,
GraphExecutionState,
ModelType,
SubModelType,
} from '../api/types';
/**
* A progress image, we get one for each step in the generation
@ -25,6 +31,25 @@ export type BaseNode = {
[key: string]: AnyInvocation[keyof AnyInvocation];
};
export type ModelLoadStartedEvent = {
graph_execution_state_id: string;
model_name: string;
base_model: BaseModelType;
model_type: ModelType;
submodel: SubModelType;
};
export type ModelLoadCompletedEvent = {
graph_execution_state_id: string;
model_name: string;
base_model: BaseModelType;
model_type: ModelType;
submodel: SubModelType;
hash?: string;
location: string;
precision: string;
};
/**
* A `generator_progress` socket.io event.
*
@ -101,6 +126,8 @@ export type ServerToClientEvents = {
graph_execution_state_complete: (
payload: GraphExecutionStateCompleteEvent
) => void;
model_load_started: (payload: ModelLoadStartedEvent) => void;
model_load_completed: (payload: ModelLoadCompletedEvent) => void;
};
export type ClientToServerEvents = {

View File

@ -11,6 +11,8 @@ import {
socketConnected,
socketDisconnected,
socketSubscribed,
socketModelLoadStarted,
socketModelLoadCompleted,
} from '../actions';
import { ClientToServerEvents, ServerToClientEvents } from '../types';
import { Logger } from 'roarr';
@ -44,7 +46,7 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
socketSubscribed({
sessionId,
timestamp: getTimestamp(),
boardId: getState().boards.selectedBoardId,
boardId: getState().gallery.selectedBoardId,
})
);
}
@ -118,4 +120,28 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
})
);
});
/**
* Model load started
*/
socket.on('model_load_started', (data) => {
dispatch(
socketModelLoadStarted({
data,
timestamp: getTimestamp(),
})
);
});
/**
* Model load completed
*/
socket.on('model_load_completed', (data) => {
dispatch(
socketModelLoadCompleted({
data,
timestamp: getTimestamp(),
})
);
});
};

View File

@ -16,7 +16,7 @@ const invokeAI = defineStyle((props) => {
};
return {
bg: mode('base.200', 'base.600')(props),
bg: mode('base.250', 'base.600')(props),
color: mode('base.850', 'base.100')(props),
borderRadius: 'base',
svg: {

View File

@ -1,6 +1,7 @@
import { menuAnatomy } from '@chakra-ui/anatomy';
import { createMultiStyleConfigHelpers } from '@chakra-ui/react';
import { mode } from '@chakra-ui/theme-tools';
import { MotionProps } from 'framer-motion';
const { definePartsStyle, defineMultiStyleConfig } =
createMultiStyleConfigHelpers(menuAnatomy.keys);
@ -21,6 +22,7 @@ const invokeAI = definePartsStyle((props) => ({
},
list: {
zIndex: 9999,
color: mode('base.900', 'base.150')(props),
bg: mode('base.200', 'base.800')(props),
shadow: 'dark-lg',
border: 'none',
@ -35,6 +37,9 @@ const invokeAI = definePartsStyle((props) => ({
_focus: {
bg: mode('base.400', 'base.600')(props),
},
svg: {
opacity: 0.5,
},
},
}));
@ -46,3 +51,28 @@ export const menuTheme = defineMultiStyleConfig({
variant: 'invokeAI',
},
});
export const menuListMotionProps: MotionProps = {
variants: {
enter: {
visibility: 'visible',
opacity: 1,
scale: 1,
transition: {
duration: 0.07,
ease: [0.4, 0, 0.2, 1],
},
},
exit: {
transitionEnd: {
visibility: 'hidden',
},
opacity: 0,
scale: 0.8,
transition: {
duration: 0.07,
easings: 'easeOut',
},
},
},
};

File diff suppressed because it is too large Load Diff

45
pull_request_template.md Normal file
View File

@ -0,0 +1,45 @@
## What type of PR is this? (check all applicable)
- [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
## Have you discussed this change with the InvokeAI team?
- [ ] Yes
- [ ] No, because:
## Description
## Related Tickets & Documents
<!--
For pull requests that relate or close an issue, please include them
below.
For example having the text: "closes #1234" would connect the current pull
request to issue 1234. And when we merge the pull request, Github will
automatically close the issue.
-->
- Related Issue #
- Closes #
## QA Instructions, Screenshots, Recordings
<!--
Please provide steps on how to test changes, any hardware or
software specifications as well as any other pertinent information.
-->
## Added/updated tests?
- [ ] Yes
- [ ] No : _please replace this line with details on why tests
have not been included_
## [optional] Are there any post deployment tasks we need to perform?

View File

@ -38,7 +38,7 @@ dependencies = [
"albumentations",
"click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel>=1.2.1",
"compel==2.0.0rc2",
"controlnet-aux>=0.0.6",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"datasets",