mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into feat/model-events
This commit is contained in:
commit
3e2a948007
19
README.md
19
README.md
@ -132,8 +132,10 @@ and go to http://localhost:9090.
|
|||||||
|
|
||||||
### Command-Line Installation (for developers and users familiar with Terminals)
|
### Command-Line Installation (for developers and users familiar with Terminals)
|
||||||
|
|
||||||
You must have Python 3.9 or 3.10 installed on your machine. Earlier or later versions are
|
You must have Python 3.9 or 3.10 installed on your machine. Earlier or
|
||||||
not supported.
|
later versions are not supported.
|
||||||
|
Node.js also needs to be installed along with yarn (can be installed with
|
||||||
|
the command `npm install -g yarn` if needed)
|
||||||
|
|
||||||
1. Open a command-line window on your machine. The PowerShell is recommended for Windows.
|
1. Open a command-line window on your machine. The PowerShell is recommended for Windows.
|
||||||
2. Create a directory to install InvokeAI into. You'll need at least 15 GB of free space:
|
2. Create a directory to install InvokeAI into. You'll need at least 15 GB of free space:
|
||||||
@ -197,11 +199,18 @@ not supported.
|
|||||||
7. Launch the web server (do it every time you run InvokeAI):
|
7. Launch the web server (do it every time you run InvokeAI):
|
||||||
|
|
||||||
```terminal
|
```terminal
|
||||||
invokeai --web
|
invokeai-web
|
||||||
```
|
```
|
||||||
|
|
||||||
8. Point your browser to http://localhost:9090 to bring up the web interface.
|
8. Build Node.js assets
|
||||||
9. Type `banana sushi` in the box on the top left and click `Invoke`.
|
|
||||||
|
```terminal
|
||||||
|
cd invokeai/frontend/web/
|
||||||
|
yarn vite build
|
||||||
|
```
|
||||||
|
|
||||||
|
9. Point your browser to http://localhost:9090 to bring up the web interface.
|
||||||
|
10. Type `banana sushi` in the box on the top left and click `Invoke`.
|
||||||
|
|
||||||
Be sure to activate the virtual environment each time before re-launching InvokeAI,
|
Be sure to activate the virtual environment each time before re-launching InvokeAI,
|
||||||
using `source .venv/bin/activate` or `.venv\Scripts\activate`.
|
using `source .venv/bin/activate` or `.venv\Scripts\activate`.
|
||||||
|
287
docs/features/CONFIGURATION.md
Normal file
287
docs/features/CONFIGURATION.md
Normal file
@ -0,0 +1,287 @@
|
|||||||
|
---
|
||||||
|
title: Configuration
|
||||||
|
---
|
||||||
|
|
||||||
|
# :material-tune-variant: InvokeAI Configuration
|
||||||
|
|
||||||
|
## Intro
|
||||||
|
|
||||||
|
InvokeAI has numerous runtime settings which can be used to adjust
|
||||||
|
many aspects of its operations, including the location of files and
|
||||||
|
directories, memory usage, and performance. These settings can be
|
||||||
|
viewed and customized in several ways:
|
||||||
|
|
||||||
|
1. By editing settings in the `invokeai.yaml` file.
|
||||||
|
2. By setting environment variables.
|
||||||
|
3. On the command-line, when InvokeAI is launched.
|
||||||
|
|
||||||
|
In addition, the most commonly changed settings are accessible
|
||||||
|
graphically via the `invokeai-configure` script.
|
||||||
|
|
||||||
|
### How the Configuration System Works
|
||||||
|
|
||||||
|
When InvokeAI is launched, the very first thing it needs to do is to
|
||||||
|
find its "root" directory, which contains its configuration files,
|
||||||
|
installed models, its database of images, and the folder(s) of
|
||||||
|
generated images themselves. In this document, the root directory will
|
||||||
|
be referred to as ROOT.
|
||||||
|
|
||||||
|
#### Finding the Root Directory
|
||||||
|
|
||||||
|
To find its root directory, InvokeAI uses the following recipe:
|
||||||
|
|
||||||
|
1. It first looks for the argument `--root <path>` on the command line
|
||||||
|
it was launched from, and uses the indicated path if present.
|
||||||
|
|
||||||
|
2. Next it looks for the environment variable INVOKEAI_ROOT, and uses
|
||||||
|
the directory path found there if present.
|
||||||
|
|
||||||
|
3. If neither of these are present, then InvokeAI looks for the
|
||||||
|
folder containing the `.venv` Python virtual environment directory for
|
||||||
|
the currently active environment. This directory is checked for files
|
||||||
|
expected inside the InvokeAI root before it is used.
|
||||||
|
|
||||||
|
4. Finally, InvokeAI looks for a directory in the current user's home
|
||||||
|
directory named `invokeai`.
|
||||||
|
|
||||||
|
#### Reading the InvokeAI Configuration File
|
||||||
|
|
||||||
|
Once the root directory has been located, InvokeAI looks for a file
|
||||||
|
named `ROOT/invokeai.yaml`, and if present reads configuration values
|
||||||
|
from it. The top of this file looks like this:
|
||||||
|
|
||||||
|
```
|
||||||
|
InvokeAI:
|
||||||
|
Web Server:
|
||||||
|
host: localhost
|
||||||
|
port: 9090
|
||||||
|
allow_origins: []
|
||||||
|
allow_credentials: true
|
||||||
|
allow_methods:
|
||||||
|
- '*'
|
||||||
|
allow_headers:
|
||||||
|
- '*'
|
||||||
|
Features:
|
||||||
|
esrgan: true
|
||||||
|
internet_available: true
|
||||||
|
log_tokenization: false
|
||||||
|
nsfw_checker: false
|
||||||
|
patchmatch: true
|
||||||
|
restore: true
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
This lines in this file are used to establish default values for
|
||||||
|
Invoke's settings. In the above fragment, the Web Server's listening
|
||||||
|
port is set to 9090 by the `port` setting.
|
||||||
|
|
||||||
|
You can edit this file with a text editor such as "Notepad" (do not
|
||||||
|
use Word or any other word processor). When editing, be careful to
|
||||||
|
maintain the indentation, and do not add extraneous text, as syntax
|
||||||
|
errors will prevent InvokeAI from launching. A basic guide to the
|
||||||
|
format of YAML files can be found
|
||||||
|
[here](https://circleci.com/blog/what-is-yaml-a-beginner-s-guide/).
|
||||||
|
|
||||||
|
You can fix a broken `invokeai.yaml` by deleting it and running the
|
||||||
|
configuration script again -- option [7] in the launcher, "Re-run the
|
||||||
|
configure script".
|
||||||
|
|
||||||
|
#### Reading Environment Variables
|
||||||
|
|
||||||
|
Next InvokeAI looks for defined environment variables in the format
|
||||||
|
`INVOKEAI_<setting_name>`, for example `INVOKEAI_port`. Environment
|
||||||
|
variable values take precedence over configuration file variables. On
|
||||||
|
a Macintosh system, for example, you could change the port that the
|
||||||
|
web server listens on by setting the environment variable this way:
|
||||||
|
|
||||||
|
```
|
||||||
|
export INVOKEAI_port=8000
|
||||||
|
invokeai-web
|
||||||
|
```
|
||||||
|
|
||||||
|
Please check out these
|
||||||
|
[Macintosh](https://phoenixnap.com/kb/set-environment-variable-mac)
|
||||||
|
and
|
||||||
|
[Windows](https://phoenixnap.com/kb/windows-set-environment-variable)
|
||||||
|
guides for setting temporary and permanent environment variables.
|
||||||
|
|
||||||
|
#### Reading the Command Line
|
||||||
|
|
||||||
|
Lastly, InvokeAI takes settings from the command line, which override
|
||||||
|
everything else. The command-line settings have the same name as the
|
||||||
|
corresponding configuration file settings, preceded by a `--`, for
|
||||||
|
example `--port 8000`.
|
||||||
|
|
||||||
|
If you are using the launcher (`invoke.sh` or `invoke.bat`) to launch
|
||||||
|
InvokeAI, then just pass the command-line arguments to the launcher:
|
||||||
|
|
||||||
|
```
|
||||||
|
invoke.bat --port 8000 --host 0.0.0.0
|
||||||
|
```
|
||||||
|
|
||||||
|
The arguments will be applied when you select the web server option
|
||||||
|
(and the other options as well).
|
||||||
|
|
||||||
|
If, on the other hand, you prefer to launch InvokeAI directly from the
|
||||||
|
command line, you would first activate the virtual environment (known
|
||||||
|
as the "developer's console" in the launcher), and run `invokeai-web`:
|
||||||
|
|
||||||
|
```
|
||||||
|
> C:\Users\Fred\invokeai\.venv\scripts\activate
|
||||||
|
(.venv) > invokeai-web --port 8000 --host 0.0.0.0
|
||||||
|
```
|
||||||
|
|
||||||
|
You can get a listing and brief instructions for each of the
|
||||||
|
command-line options by giving the `--help` argument:
|
||||||
|
|
||||||
|
```
|
||||||
|
(.venv) > invokeai-web --help
|
||||||
|
usage: InvokeAI [-h] [--host HOST] [--port PORT] [--allow_origins [ALLOW_ORIGINS ...]] [--allow_credentials | --no-allow_credentials]
|
||||||
|
[--allow_methods [ALLOW_METHODS ...]] [--allow_headers [ALLOW_HEADERS ...]] [--esrgan | --no-esrgan]
|
||||||
|
[--internet_available | --no-internet_available] [--log_tokenization | --no-log_tokenization]
|
||||||
|
[--nsfw_checker | --no-nsfw_checker] [--patchmatch | --no-patchmatch] [--restore | --no-restore]
|
||||||
|
[--always_use_cpu | --no-always_use_cpu] [--free_gpu_mem | --no-free_gpu_mem] [--max_cache_size MAX_CACHE_SIZE]
|
||||||
|
[--max_vram_cache_size MAX_VRAM_CACHE_SIZE] [--precision {auto,float16,float32,autocast}]
|
||||||
|
[--sequential_guidance | --no-sequential_guidance] [--xformers_enabled | --no-xformers_enabled]
|
||||||
|
[--tiled_decode | --no-tiled_decode] [--root ROOT] [--autoimport_dir AUTOIMPORT_DIR] [--lora_dir LORA_DIR]
|
||||||
|
[--embedding_dir EMBEDDING_DIR] [--controlnet_dir CONTROLNET_DIR] [--conf_path CONF_PATH] [--models_dir MODELS_DIR]
|
||||||
|
[--legacy_conf_dir LEGACY_CONF_DIR] [--db_dir DB_DIR] [--outdir OUTDIR] [--from_file FROM_FILE]
|
||||||
|
[--use_memory_db | --no-use_memory_db] [--model MODEL] [--log_handlers [LOG_HANDLERS ...]]
|
||||||
|
[--log_format {plain,color,syslog,legacy}] [--log_level {debug,info,warning,error,critical}]
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## The Configuration Settings
|
||||||
|
|
||||||
|
The configuration settings are divided into several distinct
|
||||||
|
groups in `invokeia.yaml`:
|
||||||
|
|
||||||
|
### Web Server
|
||||||
|
|
||||||
|
| Setting | Default Value | Description |
|
||||||
|
|----------|----------------|--------------|
|
||||||
|
| `host` | `localhost` | Name or IP address of the network interface that the web server will listen on |
|
||||||
|
| `port` | `9090` | Network port number that the web server will listen on |
|
||||||
|
| `allow_origins` | `[]` | A list of host names or IP addresses that are allowed to connect to the InvokeAI API in the format `['host1','host2',...]` |
|
||||||
|
| `allow_credentials | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
|
||||||
|
| `allow_methods` | `*` | List of HTTP methods ("GET", "POST") that the web server is allowed to use when accessing the API |
|
||||||
|
| `allow_headers` | `*` | List of HTTP headers that the web server will accept when accessing the API |
|
||||||
|
|
||||||
|
The documentation for InvokeAI's API can be accessed by browsing to the following URL: [http://localhost:9090/docs].
|
||||||
|
|
||||||
|
### Features
|
||||||
|
|
||||||
|
These configuration settings allow you to enable and disable various InvokeAI features:
|
||||||
|
|
||||||
|
| Setting | Default Value | Description |
|
||||||
|
|----------|----------------|--------------|
|
||||||
|
| `esrgan` | `true` | Activate the ESRGAN upscaling options|
|
||||||
|
| `internet_available` | `true` | When a resource is not available locally, try to fetch it via the internet |
|
||||||
|
| `log_tokenization` | `false` | Before each text2image generation, print a color-coded representation of the prompt to the console; this can help understand why a prompt is not working as expected |
|
||||||
|
| `nsfw_checker` | `true` | Activate the NSFW checker to blur out risque images |
|
||||||
|
| `patchmatch` | `true` | Activate the "patchmatch" algorithm for improved inpainting |
|
||||||
|
| `restore` | `true` | Activate the facial restoration features (DEPRECATED; restoration features will be removed in 3.0.0) |
|
||||||
|
|
||||||
|
### Memory/Performance
|
||||||
|
|
||||||
|
These options tune InvokeAI's memory and performance characteristics.
|
||||||
|
|
||||||
|
| Setting | Default Value | Description |
|
||||||
|
|----------|----------------|--------------|
|
||||||
|
| `always_use_cpu` | `false` | Use the CPU to generate images, even if a GPU is available |
|
||||||
|
| `free_gpu_mem` | `false` | Aggressively free up GPU memory after each operation; this will allow you to run in low-VRAM environments with some performance penalties |
|
||||||
|
| `max_cache_size` | `6` | Amount of CPU RAM (in GB) to reserve for caching models in memory; more cache allows you to keep models in memory and switch among them quickly |
|
||||||
|
| `max_vram_cache_size` | `2.75` | Amount of GPU VRAM (in GB) to reserve for caching models in VRAM; more cache speeds up generation but reduces the size of the images that can be generated. This can be set to zero to maximize the amount of memory available for generation. |
|
||||||
|
| `precision` | `auto` | Floating point precision. One of `auto`, `float16` or `float32`. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system |
|
||||||
|
| `sequential_guidance` | `false` | Calculate guidance in serial rather than in parallel, lowering memory requirements at the cost of some performance loss |
|
||||||
|
| `xformers_enabled` | `true` | If the x-formers memory-efficient attention module is installed, activate it for better memory usage and generation speed|
|
||||||
|
| `tiled_decode` | `false` | If true, then during the VAE decoding phase the image will be decoded a section at a time, reducing memory consumption at the cost of a performance hit |
|
||||||
|
|
||||||
|
### Paths
|
||||||
|
|
||||||
|
These options set the paths of various directories and files used by
|
||||||
|
InvokeAI. Relative paths are interpreted relative to INVOKEAI_ROOT, so
|
||||||
|
if INVOKEAI_ROOT is `/home/fred/invokeai` and the path is
|
||||||
|
`autoimport/main`, then the corresponding directory will be located at
|
||||||
|
`/home/fred/invokeai/autoimport/main`.
|
||||||
|
|
||||||
|
| Setting | Default Value | Description |
|
||||||
|
|----------|----------------|--------------|
|
||||||
|
| `autoimport_dir` | `autoimport/main` | At startup time, read and import any main model files found in this directory |
|
||||||
|
| `lora_dir` | `autoimport/lora` | At startup time, read and import any LoRA/LyCORIS models found in this directory |
|
||||||
|
| `embedding_dir` | `autoimport/embedding` | At startup time, read and import any textual inversion (embedding) models found in this directory |
|
||||||
|
| `controlnet_dir` | `autoimport/controlnet` | At startup time, read and import any ControlNet models found in this directory |
|
||||||
|
| `conf_path` | `configs/models.yaml` | Location of the `models.yaml` model configuration file |
|
||||||
|
| `models_dir` | `models` | Location of the directory containing models installed by InvokeAI's model manager |
|
||||||
|
| `legacy_conf_dir` | `configs/stable-diffusion` | Location of the directory containing the .yaml configuration files for legacy checkpoint models |
|
||||||
|
| `db_dir` | `databases` | Location of the directory containing InvokeAI's image, schema and session database |
|
||||||
|
| `outdir` | `outputs` | Location of the directory in which the gallery of generated and uploaded images will be stored |
|
||||||
|
| `use_memory_db` | `false` | Keep database information in memory rather than on disk; this will not preserve image gallery information across restarts |
|
||||||
|
|
||||||
|
Note that the autoimport directories will be searched recursively,
|
||||||
|
allowing you to organize the models into folders and subfolders in any
|
||||||
|
way you wish. In addition, while we have split up autoimport
|
||||||
|
directories by the type of model they contain, this isn't
|
||||||
|
necessary. You can combine different model types in the same folder
|
||||||
|
and InvokeAI will figure out what they are. So you can easily use just
|
||||||
|
one autoimport directory by commenting out the unneeded paths:
|
||||||
|
|
||||||
|
```
|
||||||
|
Paths:
|
||||||
|
autoimport_dir: autoimport
|
||||||
|
# lora_dir: null
|
||||||
|
# embedding_dir: null
|
||||||
|
# controlnet_dir: null
|
||||||
|
```
|
||||||
|
|
||||||
|
### Logging
|
||||||
|
|
||||||
|
These settings control the information, warning, and debugging
|
||||||
|
messages printed to the console log while InvokeAI is running:
|
||||||
|
|
||||||
|
| Setting | Default Value | Description |
|
||||||
|
|----------|----------------|--------------|
|
||||||
|
| `log_handlers` | `console` | This controls where log messages are sent, and can be a list of one or more destinations. Values include `console`, `file`, `syslog` and `http`. These are described in more detail below |
|
||||||
|
| `log_format` | `color` | This controls the formatting of the log messages. Values are `plain`, `color`, `legacy` and `syslog` |
|
||||||
|
| `log_level` | `debug` | This filters messages according to the level of severity and can be one of `debug`, `info`, `warning`, `error` and `critical`. For example, setting to `warning` will display all messages at the warning level or higher, but won't display "debug" or "info" messages |
|
||||||
|
|
||||||
|
Several different log handler destinations are available, and multiple destinations are supported by providing a list:
|
||||||
|
|
||||||
|
```
|
||||||
|
log_handlers:
|
||||||
|
- console
|
||||||
|
- syslog=localhost
|
||||||
|
- file=/var/log/invokeai.log
|
||||||
|
```
|
||||||
|
|
||||||
|
* `console` is the default. It prints log messages to the command-line window from which InvokeAI was launched.
|
||||||
|
|
||||||
|
* `syslog` is only available on Linux and Macintosh systems. It uses
|
||||||
|
the operating system's "syslog" facility to write log file entries
|
||||||
|
locally or to a remote logging machine. `syslog` offers a variety
|
||||||
|
of configuration options:
|
||||||
|
|
||||||
|
```
|
||||||
|
syslog=/dev/log` - log to the /dev/log device
|
||||||
|
syslog=localhost` - log to the network logger running on the local machine
|
||||||
|
syslog=localhost:512` - same as above, but using a non-standard port
|
||||||
|
syslog=fredserver,facility=LOG_USER,socktype=SOCK_DRAM`
|
||||||
|
- Log to LAN-connected server "fredserver" using the facility LOG_USER and datagram packets.
|
||||||
|
```
|
||||||
|
|
||||||
|
* `http` can be used to log to a remote web server. The server must be
|
||||||
|
properly configured to receive and act on log messages. The option
|
||||||
|
accepts the URL to the web server, and a `method` argument
|
||||||
|
indicating whether the message should be submitted using the GET or
|
||||||
|
POST method.
|
||||||
|
|
||||||
|
```
|
||||||
|
http=http://my.server/path/to/logger,method=POST
|
||||||
|
```
|
||||||
|
|
||||||
|
The `log_format` option provides several alternative formats:
|
||||||
|
|
||||||
|
* `color` - default format providing time, date and a message, using text colors to distinguish different log severities
|
||||||
|
* `plain` - same as above, but monochrome text only
|
||||||
|
* `syslog` - the log level and error message only, allowing the syslog system to attach the time and date
|
||||||
|
* `legacy` - a format similar to the one used by the legacy 2.3 InvokeAI releases.
|
@ -1,4 +1,8 @@
|
|||||||
# Nodes Editor (Experimental Beta)
|
# Nodes Editor (Experimental)
|
||||||
|
|
||||||
|
🚨
|
||||||
|
*The node editor is experimental. We've made it accessible because we use it to develop the application, but we have not addressed the many known rough edges. It's very easy to shoot yourself in the foot, and we cannot offer support for it until it sees full release (ETA v3.1). Everything is subject to change without warning.*
|
||||||
|
🚨
|
||||||
|
|
||||||
The nodes editor is a blank canvas allowing for the use of individual functions and image transformations to control the image generation workflow. The node processing flow is usually done from left (inputs) to right (outputs), though linearity can become abstracted the more complex the node graph becomes. Nodes inputs and outputs are connected by dragging connectors from node to node.
|
The nodes editor is a blank canvas allowing for the use of individual functions and image transformations to control the image generation workflow. The node processing flow is usually done from left (inputs) to right (outputs), though linearity can become abstracted the more complex the node graph becomes. Nodes inputs and outputs are connected by dragging connectors from node to node.
|
||||||
|
|
||||||
|
@ -153,6 +153,9 @@ This method is recommended for those familiar with running Docker containers
|
|||||||
- [Prompt Syntax](features/PROMPTS.md)
|
- [Prompt Syntax](features/PROMPTS.md)
|
||||||
- [Generating Variations](features/VARIATIONS.md)
|
- [Generating Variations](features/VARIATIONS.md)
|
||||||
|
|
||||||
|
### InvokeAI Configuration
|
||||||
|
- [Guide to InvokeAI Runtime Settings](features/CONFIGURATION.md)
|
||||||
|
|
||||||
## :octicons-log-16: Important Changes Since Version 2.3
|
## :octicons-log-16: Important Changes Since Version 2.3
|
||||||
|
|
||||||
### Nodes
|
### Nodes
|
||||||
|
@ -11,6 +11,7 @@ from invokeai.app.services.board_images import (
|
|||||||
)
|
)
|
||||||
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
||||||
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||||
from invokeai.app.services.resource_name import SimpleNameService
|
from invokeai.app.services.resource_name import SimpleNameService
|
||||||
@ -20,7 +21,6 @@ from invokeai.version.invokeai_version import __version__
|
|||||||
|
|
||||||
from ..services.default_graphs import create_system_graphs
|
from ..services.default_graphs import create_system_graphs
|
||||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
from ..services.restoration_services import RestorationServices
|
|
||||||
from ..services.graph import GraphExecutionState, LibraryGraph
|
from ..services.graph import GraphExecutionState, LibraryGraph
|
||||||
from ..services.image_file_storage import DiskImageFileStorage
|
from ..services.image_file_storage import DiskImageFileStorage
|
||||||
from ..services.invocation_queue import MemoryInvocationQueue
|
from ..services.invocation_queue import MemoryInvocationQueue
|
||||||
@ -57,8 +57,8 @@ class ApiDependencies:
|
|||||||
invoker: Invoker = None
|
invoker: Invoker = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def initialize(config, event_handler_id: int, logger: Logger = logger):
|
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
|
||||||
logger.debug(f'InvokeAI version {__version__}')
|
logger.debug(f"InvokeAI version {__version__}")
|
||||||
logger.debug(f"Internet connectivity is {config.internet_available}")
|
logger.debug(f"Internet connectivity is {config.internet_available}")
|
||||||
|
|
||||||
events = FastAPIEventService(event_handler_id)
|
events = FastAPIEventService(event_handler_id)
|
||||||
@ -117,7 +117,7 @@ class ApiDependencies:
|
|||||||
)
|
)
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
model_manager=ModelManagerService(config,logger),
|
model_manager=ModelManagerService(config, logger),
|
||||||
events=events,
|
events=events,
|
||||||
latents=latents,
|
latents=latents,
|
||||||
images=images,
|
images=images,
|
||||||
@ -129,7 +129,6 @@ class ApiDependencies:
|
|||||||
),
|
),
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
restoration=RestorationServices(config, logger),
|
|
||||||
configuration=config,
|
configuration=config,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
)
|
)
|
||||||
|
@ -13,8 +13,10 @@ from invokeai.backend import BaseModelType, ModelType
|
|||||||
from invokeai.backend.model_management.models import (
|
from invokeai.backend.model_management.models import (
|
||||||
OPENAPI_MODEL_CONFIGS,
|
OPENAPI_MODEL_CONFIGS,
|
||||||
SchedulerPredictionType,
|
SchedulerPredictionType,
|
||||||
|
ModelNotFoundException,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_management import MergeInterpolationMethod
|
from invokeai.backend.model_management import MergeInterpolationMethod
|
||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||||
@ -46,8 +48,9 @@ async def list_models(
|
|||||||
"/{base_model}/{model_type}/{model_name}",
|
"/{base_model}/{model_type}/{model_name}",
|
||||||
operation_id="update_model",
|
operation_id="update_model",
|
||||||
responses={200: {"description" : "The model was updated successfully"},
|
responses={200: {"description" : "The model was updated successfully"},
|
||||||
|
400: {"description" : "Bad request"},
|
||||||
404: {"description" : "The model could not be found"},
|
404: {"description" : "The model could not be found"},
|
||||||
400: {"description" : "Bad request"}
|
409: {"description" : "There is already a model corresponding to the new name"},
|
||||||
},
|
},
|
||||||
status_code = 200,
|
status_code = 200,
|
||||||
response_model = UpdateModelResponse,
|
response_model = UpdateModelResponse,
|
||||||
@ -58,23 +61,43 @@ async def update_model(
|
|||||||
model_name: str = Path(description="model name"),
|
model_name: str = Path(description="model name"),
|
||||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||||
) -> UpdateModelResponse:
|
) -> UpdateModelResponse:
|
||||||
""" Add Model """
|
""" Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. """
|
||||||
|
logger = ApiDependencies.invoker.services.logger
|
||||||
try:
|
try:
|
||||||
|
# rename operation requested
|
||||||
|
if info.model_name != model_name or info.base_model != base_model:
|
||||||
|
result = ApiDependencies.invoker.services.model_manager.rename_model(
|
||||||
|
base_model = base_model,
|
||||||
|
model_type = model_type,
|
||||||
|
model_name = model_name,
|
||||||
|
new_name = info.model_name,
|
||||||
|
new_base = info.base_model,
|
||||||
|
)
|
||||||
|
logger.debug(f'renaming result = {result}')
|
||||||
|
logger.info(f'Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}')
|
||||||
|
model_name = info.model_name
|
||||||
|
base_model = info.base_model
|
||||||
|
|
||||||
ApiDependencies.invoker.services.model_manager.update_model(
|
ApiDependencies.invoker.services.model_manager.update_model(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
model_attributes=info.dict()
|
model_attributes=info.dict()
|
||||||
)
|
)
|
||||||
|
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
)
|
)
|
||||||
model_response = parse_obj_as(UpdateModelResponse, model_raw)
|
model_response = parse_obj_as(UpdateModelResponse, model_raw)
|
||||||
except KeyError as e:
|
except ModelNotFoundException as e:
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
raise HTTPException(status_code=409, detail=str(e))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
return model_response
|
return model_response
|
||||||
@ -121,7 +144,7 @@ async def import_model(
|
|||||||
)
|
)
|
||||||
return parse_obj_as(ImportModelResponse, model_raw)
|
return parse_obj_as(ImportModelResponse, model_raw)
|
||||||
|
|
||||||
except KeyError as e:
|
except ModelNotFoundException as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@ -161,57 +184,13 @@ async def add_model(
|
|||||||
model_type=info.model_type
|
model_type=info.model_type
|
||||||
)
|
)
|
||||||
return parse_obj_as(ImportModelResponse, model_raw)
|
return parse_obj_as(ImportModelResponse, model_raw)
|
||||||
except KeyError as e:
|
except ModelNotFoundException as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=409, detail=str(e))
|
raise HTTPException(status_code=409, detail=str(e))
|
||||||
|
|
||||||
@models_router.post(
|
|
||||||
"/rename/{base_model}/{model_type}/{model_name}",
|
|
||||||
operation_id="rename_model",
|
|
||||||
responses= {
|
|
||||||
201: {"description" : "The model was renamed successfully"},
|
|
||||||
404: {"description" : "The model could not be found"},
|
|
||||||
409: {"description" : "There is already a model corresponding to the new name"},
|
|
||||||
},
|
|
||||||
status_code=201,
|
|
||||||
response_model=ImportModelResponse
|
|
||||||
)
|
|
||||||
async def rename_model(
|
|
||||||
base_model: BaseModelType = Path(description="Base model"),
|
|
||||||
model_type: ModelType = Path(description="The type of model"),
|
|
||||||
model_name: str = Path(description="current model name"),
|
|
||||||
new_name: Optional[str] = Query(description="new model name", default=None),
|
|
||||||
new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
|
|
||||||
) -> ImportModelResponse:
|
|
||||||
""" Rename a model"""
|
|
||||||
|
|
||||||
logger = ApiDependencies.invoker.services.logger
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = ApiDependencies.invoker.services.model_manager.rename_model(
|
|
||||||
base_model = base_model,
|
|
||||||
model_type = model_type,
|
|
||||||
model_name = model_name,
|
|
||||||
new_name = new_name,
|
|
||||||
new_base = new_base,
|
|
||||||
)
|
|
||||||
logger.debug(result)
|
|
||||||
logger.info(f'Successfully renamed {model_name}=>{new_name}')
|
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
|
||||||
model_name=new_name or model_name,
|
|
||||||
base_model=new_base or base_model,
|
|
||||||
model_type=model_type
|
|
||||||
)
|
|
||||||
return parse_obj_as(ImportModelResponse, model_raw)
|
|
||||||
except KeyError as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
raise HTTPException(status_code=409, detail=str(e))
|
|
||||||
|
|
||||||
@models_router.delete(
|
@models_router.delete(
|
||||||
"/{base_model}/{model_type}/{model_name}",
|
"/{base_model}/{model_type}/{model_name}",
|
||||||
@ -238,9 +217,9 @@ async def delete_model(
|
|||||||
)
|
)
|
||||||
logger.info(f"Deleted model: {model_name}")
|
logger.info(f"Deleted model: {model_name}")
|
||||||
return Response(status_code=204)
|
return Response(status_code=204)
|
||||||
except KeyError:
|
except ModelNotFoundException as e:
|
||||||
logger.error(f"Model not found: {model_name}")
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
|
||||||
@models_router.put(
|
@models_router.put(
|
||||||
"/convert/{base_model}/{model_type}/{model_name}",
|
"/convert/{base_model}/{model_type}/{model_name}",
|
||||||
@ -273,8 +252,8 @@ async def convert_model(
|
|||||||
base_model = base_model,
|
base_model = base_model,
|
||||||
model_type = model_type)
|
model_type = model_type)
|
||||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||||
except KeyError:
|
except ModelNotFoundException as e:
|
||||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
return response
|
return response
|
||||||
@ -364,8 +343,55 @@ async def merge_models(
|
|||||||
model_type = ModelType.Main,
|
model_type = ModelType.Main,
|
||||||
)
|
)
|
||||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||||
except KeyError:
|
except ModelNotFoundException:
|
||||||
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
|
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
# The rename operation is now supported by update_model and no longer needs to be
|
||||||
|
# a standalone route.
|
||||||
|
# @models_router.post(
|
||||||
|
# "/rename/{base_model}/{model_type}/{model_name}",
|
||||||
|
# operation_id="rename_model",
|
||||||
|
# responses= {
|
||||||
|
# 201: {"description" : "The model was renamed successfully"},
|
||||||
|
# 404: {"description" : "The model could not be found"},
|
||||||
|
# 409: {"description" : "There is already a model corresponding to the new name"},
|
||||||
|
# },
|
||||||
|
# status_code=201,
|
||||||
|
# response_model=ImportModelResponse
|
||||||
|
# )
|
||||||
|
# async def rename_model(
|
||||||
|
# base_model: BaseModelType = Path(description="Base model"),
|
||||||
|
# model_type: ModelType = Path(description="The type of model"),
|
||||||
|
# model_name: str = Path(description="current model name"),
|
||||||
|
# new_name: Optional[str] = Query(description="new model name", default=None),
|
||||||
|
# new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
|
||||||
|
# ) -> ImportModelResponse:
|
||||||
|
# """ Rename a model"""
|
||||||
|
|
||||||
|
# logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# result = ApiDependencies.invoker.services.model_manager.rename_model(
|
||||||
|
# base_model = base_model,
|
||||||
|
# model_type = model_type,
|
||||||
|
# model_name = model_name,
|
||||||
|
# new_name = new_name,
|
||||||
|
# new_base = new_base,
|
||||||
|
# )
|
||||||
|
# logger.debug(result)
|
||||||
|
# logger.info(f'Successfully renamed {model_name}=>{new_name}')
|
||||||
|
# model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
|
# model_name=new_name or model_name,
|
||||||
|
# base_model=new_base or base_model,
|
||||||
|
# model_type=model_type
|
||||||
|
# )
|
||||||
|
# return parse_obj_as(ImportModelResponse, model_raw)
|
||||||
|
# except ModelNotFoundException as e:
|
||||||
|
# logger.error(str(e))
|
||||||
|
# raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
# except ValueError as e:
|
||||||
|
# logger.error(str(e))
|
||||||
|
# raise HTTPException(status_code=409, detail=str(e))
|
||||||
|
@ -39,6 +39,7 @@ from .invocations.baseinvocation import BaseInvocation
|
|||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import invokeai.backend.util.hotfixes
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
import invokeai.backend.util.mps_fixes
|
import invokeai.backend.util.mps_fixes
|
||||||
|
|
||||||
|
@ -54,10 +54,10 @@ from .services.invocation_services import InvocationServices
|
|||||||
from .services.invoker import Invoker
|
from .services.invoker import Invoker
|
||||||
from .services.model_manager_service import ModelManagerService
|
from .services.model_manager_service import ModelManagerService
|
||||||
from .services.processor import DefaultInvocationProcessor
|
from .services.processor import DefaultInvocationProcessor
|
||||||
from .services.restoration_services import RestorationServices
|
|
||||||
from .services.sqlite import SqliteItemStorage
|
from .services.sqlite import SqliteItemStorage
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import invokeai.backend.util.hotfixes
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
import invokeai.backend.util.mps_fixes
|
import invokeai.backend.util.mps_fixes
|
||||||
|
|
||||||
@ -295,7 +295,6 @@ def invoke_cli():
|
|||||||
),
|
),
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
restoration=RestorationServices(config,logger=logger),
|
|
||||||
logger=logger,
|
logger=logger,
|
||||||
configuration=config,
|
configuration=config,
|
||||||
)
|
)
|
||||||
|
@ -1,55 +0,0 @@
|
|||||||
from typing import Literal, Optional
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
|
||||||
from .image import ImageOutput
|
|
||||||
|
|
||||||
|
|
||||||
class RestoreFaceInvocation(BaseInvocation):
|
|
||||||
"""Restores faces in an image."""
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["restore_face"] = "restore_face"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: Optional[ImageField] = Field(description="The input image")
|
|
||||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["restoration", "image"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
|
||||||
results = context.services.restoration.upscale_and_reconstruct(
|
|
||||||
image_list=[[image, 0]],
|
|
||||||
upscale=None,
|
|
||||||
strength=self.strength, # GFPGAN strength
|
|
||||||
save_original=False,
|
|
||||||
image_callback=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Results are image and seed, unwrap for now
|
|
||||||
# TODO: can this return multiple results?
|
|
||||||
image_dto = context.services.images.create(
|
|
||||||
image=results[0][0],
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ImageOutput(
|
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
|
||||||
width=image_dto.width,
|
|
||||||
height=image_dto.height,
|
|
||||||
)
|
|
@ -1,48 +1,112 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
||||||
|
from pathlib import Path, PosixPath
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Union, cast
|
||||||
|
|
||||||
|
import cv2 as cv
|
||||||
|
import numpy as np
|
||||||
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
|
from PIL import Image
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
|
||||||
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
from .image import ImageOutput
|
from .image import ImageOutput
|
||||||
|
|
||||||
|
# TODO: Populate this from disk?
|
||||||
|
# TODO: Use model manager to load?
|
||||||
|
REALESRGAN_MODELS = Literal[
|
||||||
|
"RealESRGAN_x4plus.pth",
|
||||||
|
"RealESRGAN_x4plus_anime_6B.pth",
|
||||||
|
"ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||||
|
]
|
||||||
|
|
||||||
class UpscaleInvocation(BaseInvocation):
|
|
||||||
"""Upscales an image."""
|
|
||||||
|
|
||||||
# fmt: off
|
class RealESRGANInvocation(BaseInvocation):
|
||||||
type: Literal["upscale"] = "upscale"
|
"""Upscales an image using RealESRGAN."""
|
||||||
|
|
||||||
# Inputs
|
type: Literal["realesrgan"] = "realesrgan"
|
||||||
image: Optional[ImageField] = Field(description="The input image", default=None)
|
image: Union[ImageField, None] = Field(default=None, description="The input image")
|
||||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
|
model_name: REALESRGAN_MODELS = Field(
|
||||||
level: Literal[2, 4] = Field(default=2, description="The upscale level")
|
default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use"
|
||||||
# fmt: on
|
)
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["upscaling", "image"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
results = context.services.restoration.upscale_and_reconstruct(
|
models_path = context.services.configuration.models_path
|
||||||
image_list=[[image, 0]],
|
|
||||||
upscale=(self.level, self.strength),
|
rrdbnet_model = None
|
||||||
strength=0.0, # GFPGAN strength
|
netscale = None
|
||||||
save_original=False,
|
esrgan_model_path = None
|
||||||
image_callback=None,
|
|
||||||
|
if self.model_name in [
|
||||||
|
"RealESRGAN_x4plus.pth",
|
||||||
|
"ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||||
|
]:
|
||||||
|
# x4 RRDBNet model
|
||||||
|
rrdbnet_model = RRDBNet(
|
||||||
|
num_in_ch=3,
|
||||||
|
num_out_ch=3,
|
||||||
|
num_feat=64,
|
||||||
|
num_block=23,
|
||||||
|
num_grow_ch=32,
|
||||||
|
scale=4,
|
||||||
|
)
|
||||||
|
netscale = 4
|
||||||
|
elif self.model_name in ["RealESRGAN_x4plus_anime_6B.pth"]:
|
||||||
|
# x4 RRDBNet model, 6 blocks
|
||||||
|
rrdbnet_model = RRDBNet(
|
||||||
|
num_in_ch=3,
|
||||||
|
num_out_ch=3,
|
||||||
|
num_feat=64,
|
||||||
|
num_block=6, # 6 blocks
|
||||||
|
num_grow_ch=32,
|
||||||
|
scale=4,
|
||||||
|
)
|
||||||
|
netscale = 4
|
||||||
|
# TODO: add x2 models handling?
|
||||||
|
# elif self.model_name in ["RealESRGAN_x2plus"]:
|
||||||
|
# # x2 RRDBNet model
|
||||||
|
# model = RRDBNet(
|
||||||
|
# num_in_ch=3,
|
||||||
|
# num_out_ch=3,
|
||||||
|
# num_feat=64,
|
||||||
|
# num_block=23,
|
||||||
|
# num_grow_ch=32,
|
||||||
|
# scale=2,
|
||||||
|
# )
|
||||||
|
# model_path = Path()
|
||||||
|
# netscale = 2
|
||||||
|
else:
|
||||||
|
msg = f"Invalid RealESRGAN model: {self.model_name}"
|
||||||
|
context.services.logger.error(msg)
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
|
||||||
|
|
||||||
|
upsampler = RealESRGANer(
|
||||||
|
scale=netscale,
|
||||||
|
model_path=str(models_path / esrgan_model_path),
|
||||||
|
model=rrdbnet_model,
|
||||||
|
half=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Results are image and seed, unwrap for now
|
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
|
||||||
# TODO: can this return multiple results?
|
cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
|
||||||
|
|
||||||
|
# We can pass an `outscale` value here, but it just resizes the image by that factor after
|
||||||
|
# upscaling, so it's kinda pointless for our purposes. If you want something other than 4x
|
||||||
|
# upscaling, you'll need to add a resize node after this one.
|
||||||
|
upscaled_image, img_mode = upsampler.enhance(cv_image)
|
||||||
|
|
||||||
|
# back to PIL
|
||||||
|
pil_image = Image.fromarray(
|
||||||
|
cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)
|
||||||
|
).convert("RGBA")
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=results[0][0],
|
image=pil_image,
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
|
@ -200,7 +200,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
type = get_args(get_type_hints(cls)['type'])[0]
|
type = get_args(get_type_hints(cls)['type'])[0]
|
||||||
field_dict = dict({type:dict()})
|
field_dict = dict({type:dict()})
|
||||||
for name,field in self.__fields__.items():
|
for name,field in self.__fields__.items():
|
||||||
if name in cls._excluded():
|
if name in cls._excluded_from_yaml():
|
||||||
continue
|
continue
|
||||||
category = field.field_info.extra.get("category") or "Uncategorized"
|
category = field.field_info.extra.get("category") or "Uncategorized"
|
||||||
value = getattr(self,name)
|
value = getattr(self,name)
|
||||||
@ -271,8 +271,13 @@ class InvokeAISettings(BaseSettings):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _excluded(self)->List[str]:
|
def _excluded(self)->List[str]:
|
||||||
# combination of deprecated parameters and internal ones
|
# internal fields that shouldn't be exposed as command line options
|
||||||
return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version']
|
return ['type','initconf']
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _excluded_from_yaml(self)->List[str]:
|
||||||
|
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
|
||||||
|
return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model', 'restore']
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
env_file_encoding = 'utf-8'
|
env_file_encoding = 'utf-8'
|
||||||
@ -361,7 +366,7 @@ setting environment variables INVOKEAI_<setting>.
|
|||||||
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
|
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
|
||||||
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
|
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
|
||||||
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
|
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
|
||||||
restore : bool = Field(default=True, description="Enable/disable face restoration code", category='Features')
|
restore : bool = Field(default=True, description="Enable/disable face restoration code (DEPRECATED)", category='DEPRECATED')
|
||||||
|
|
||||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
||||||
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
||||||
|
@ -10,10 +10,9 @@ if TYPE_CHECKING:
|
|||||||
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
||||||
from invokeai.app.services.events import EventServiceBase
|
from invokeai.app.services.events import EventServiceBase
|
||||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||||
from invokeai.app.services.restoration_services import RestorationServices
|
|
||||||
from invokeai.app.services.invocation_queue import InvocationQueueABC
|
from invokeai.app.services.invocation_queue import InvocationQueueABC
|
||||||
from invokeai.app.services.item_storage import ItemStorageABC
|
from invokeai.app.services.item_storage import ItemStorageABC
|
||||||
from invokeai.app.services.config import InvokeAISettings
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
|
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
|
||||||
from invokeai.app.services.invoker import InvocationProcessorABC
|
from invokeai.app.services.invoker import InvocationProcessorABC
|
||||||
|
|
||||||
@ -24,7 +23,7 @@ class InvocationServices:
|
|||||||
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
||||||
board_images: "BoardImagesServiceABC"
|
board_images: "BoardImagesServiceABC"
|
||||||
boards: "BoardServiceABC"
|
boards: "BoardServiceABC"
|
||||||
configuration: "InvokeAISettings"
|
configuration: "InvokeAIAppConfig"
|
||||||
events: "EventServiceBase"
|
events: "EventServiceBase"
|
||||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
|
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
|
||||||
graph_library: "ItemStorageABC"["LibraryGraph"]
|
graph_library: "ItemStorageABC"["LibraryGraph"]
|
||||||
@ -34,13 +33,12 @@ class InvocationServices:
|
|||||||
model_manager: "ModelManagerServiceBase"
|
model_manager: "ModelManagerServiceBase"
|
||||||
processor: "InvocationProcessorABC"
|
processor: "InvocationProcessorABC"
|
||||||
queue: "InvocationQueueABC"
|
queue: "InvocationQueueABC"
|
||||||
restoration: "RestorationServices"
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
board_images: "BoardImagesServiceABC",
|
board_images: "BoardImagesServiceABC",
|
||||||
boards: "BoardServiceABC",
|
boards: "BoardServiceABC",
|
||||||
configuration: "InvokeAISettings",
|
configuration: "InvokeAIAppConfig",
|
||||||
events: "EventServiceBase",
|
events: "EventServiceBase",
|
||||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
|
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
|
||||||
graph_library: "ItemStorageABC"["LibraryGraph"],
|
graph_library: "ItemStorageABC"["LibraryGraph"],
|
||||||
@ -50,7 +48,6 @@ class InvocationServices:
|
|||||||
model_manager: "ModelManagerServiceBase",
|
model_manager: "ModelManagerServiceBase",
|
||||||
processor: "InvocationProcessorABC",
|
processor: "InvocationProcessorABC",
|
||||||
queue: "InvocationQueueABC",
|
queue: "InvocationQueueABC",
|
||||||
restoration: "RestorationServices",
|
|
||||||
):
|
):
|
||||||
self.board_images = board_images
|
self.board_images = board_images
|
||||||
self.boards = boards
|
self.boards = boards
|
||||||
@ -65,4 +62,3 @@ class InvocationServices:
|
|||||||
self.model_manager = model_manager
|
self.model_manager = model_manager
|
||||||
self.processor = processor
|
self.processor = processor
|
||||||
self.queue = queue
|
self.queue = queue
|
||||||
self.restoration = restoration
|
|
||||||
|
@ -18,6 +18,7 @@ from invokeai.backend.model_management import (
|
|||||||
SchedulerPredictionType,
|
SchedulerPredictionType,
|
||||||
ModelMerger,
|
ModelMerger,
|
||||||
MergeInterpolationMethod,
|
MergeInterpolationMethod,
|
||||||
|
ModelNotFoundException,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_management.model_search import FindModels
|
from invokeai.backend.model_management.model_search import FindModels
|
||||||
|
|
||||||
@ -145,7 +146,7 @@ class ModelManagerServiceBase(ABC):
|
|||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
Update the named model with a dictionary of attributes. Will fail with a
|
Update the named model with a dictionary of attributes. Will fail with a
|
||||||
KeyErrorException if the name does not already exist.
|
ModelNotFoundException if the name does not already exist.
|
||||||
|
|
||||||
On a successful update, the config will be changed in memory. Will fail
|
On a successful update, the config will be changed in memory. Will fail
|
||||||
with an assertion error if provided attributes are incorrect or
|
with an assertion error if provided attributes are incorrect or
|
||||||
@ -447,14 +448,14 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
Update the named model with a dictionary of attributes. Will fail with a
|
Update the named model with a dictionary of attributes. Will fail with a
|
||||||
KeyError exception if the name does not already exist.
|
ModelNotFoundException exception if the name does not already exist.
|
||||||
On a successful update, the config will be changed in memory. Will fail
|
On a successful update, the config will be changed in memory. Will fail
|
||||||
with an assertion error if provided attributes are incorrect or
|
with an assertion error if provided attributes are incorrect or
|
||||||
the model name is missing. Call commit() to write changes to disk.
|
the model name is missing. Call commit() to write changes to disk.
|
||||||
"""
|
"""
|
||||||
self.logger.debug(f'update model {model_name}')
|
self.logger.debug(f'update model {model_name}')
|
||||||
if not self.model_exists(model_name, base_model, model_type):
|
if not self.model_exists(model_name, base_model, model_type):
|
||||||
raise KeyError(f"Unknown model {model_name}")
|
raise ModelNotFoundException(f"Unknown model {model_name}")
|
||||||
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
|
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
|
||||||
|
|
||||||
def del_model(
|
def del_model(
|
||||||
|
@ -1,113 +0,0 @@
|
|||||||
import sys
|
|
||||||
import traceback
|
|
||||||
import torch
|
|
||||||
from typing import types
|
|
||||||
from ...backend.restoration import Restoration
|
|
||||||
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
|
||||||
|
|
||||||
# This should be a real base class for postprocessing functions,
|
|
||||||
# but right now we just instantiate the existing gfpgan, esrgan
|
|
||||||
# and codeformer functions.
|
|
||||||
class RestorationServices:
|
|
||||||
'''Face restoration and upscaling'''
|
|
||||||
|
|
||||||
def __init__(self,args,logger:types.ModuleType):
|
|
||||||
try:
|
|
||||||
gfpgan, codeformer, esrgan = None, None, None
|
|
||||||
if args.restore or args.esrgan:
|
|
||||||
restoration = Restoration()
|
|
||||||
# TODO: redo for new model structure
|
|
||||||
if False and args.restore:
|
|
||||||
gfpgan, codeformer = restoration.load_face_restore_models(
|
|
||||||
args.gfpgan_model_path
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info("Face restoration disabled")
|
|
||||||
if False and args.esrgan:
|
|
||||||
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
|
|
||||||
else:
|
|
||||||
logger.info("Upscaling disabled")
|
|
||||||
else:
|
|
||||||
logger.info("Face restoration and upscaling disabled")
|
|
||||||
except (ModuleNotFoundError, ImportError):
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
|
|
||||||
self.device = torch.device(choose_torch_device())
|
|
||||||
self.gfpgan = gfpgan
|
|
||||||
self.codeformer = codeformer
|
|
||||||
self.esrgan = esrgan
|
|
||||||
self.logger = logger
|
|
||||||
self.logger.info('Face restoration initialized')
|
|
||||||
|
|
||||||
# note that this one method does gfpgan and codepath reconstruction, as well as
|
|
||||||
# esrgan upscaling
|
|
||||||
# TO DO: refactor into separate methods
|
|
||||||
def upscale_and_reconstruct(
|
|
||||||
self,
|
|
||||||
image_list,
|
|
||||||
facetool="gfpgan",
|
|
||||||
upscale=None,
|
|
||||||
upscale_denoise_str=0.75,
|
|
||||||
strength=0.0,
|
|
||||||
codeformer_fidelity=0.75,
|
|
||||||
save_original=False,
|
|
||||||
image_callback=None,
|
|
||||||
prefix=None,
|
|
||||||
):
|
|
||||||
results = []
|
|
||||||
for r in image_list:
|
|
||||||
image, seed = r
|
|
||||||
try:
|
|
||||||
if strength > 0:
|
|
||||||
if self.gfpgan is not None or self.codeformer is not None:
|
|
||||||
if facetool == "gfpgan":
|
|
||||||
if self.gfpgan is None:
|
|
||||||
self.logger.info(
|
|
||||||
"GFPGAN not found. Face restoration is disabled."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
image = self.gfpgan.process(image, strength, seed)
|
|
||||||
if facetool == "codeformer":
|
|
||||||
if self.codeformer is None:
|
|
||||||
self.logger.info(
|
|
||||||
"CodeFormer not found. Face restoration is disabled."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
cf_device = (
|
|
||||||
CPU_DEVICE if self.device == MPS_DEVICE else self.device
|
|
||||||
)
|
|
||||||
image = self.codeformer.process(
|
|
||||||
image=image,
|
|
||||||
strength=strength,
|
|
||||||
device=cf_device,
|
|
||||||
seed=seed,
|
|
||||||
fidelity=codeformer_fidelity,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.logger.info("Face Restoration is disabled.")
|
|
||||||
if upscale is not None:
|
|
||||||
if self.esrgan is not None:
|
|
||||||
if len(upscale) < 2:
|
|
||||||
upscale.append(0.75)
|
|
||||||
image = self.esrgan.process(
|
|
||||||
image,
|
|
||||||
upscale[1],
|
|
||||||
seed,
|
|
||||||
int(upscale[0]),
|
|
||||||
denoise_str=upscale_denoise_str,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.logger.info("ESRGAN is disabled. Image not upscaled.")
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.info(
|
|
||||||
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if image_callback is not None:
|
|
||||||
image_callback(image, seed, upscaled=True, use_prefix=prefix)
|
|
||||||
else:
|
|
||||||
r[0] = image
|
|
||||||
|
|
||||||
results.append([image, seed])
|
|
||||||
|
|
||||||
return results
|
|
@ -30,8 +30,6 @@ from huggingface_hub import login as hf_hub_login
|
|||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoProcessor,
|
|
||||||
CLIPSegForImageSegmentation,
|
|
||||||
CLIPTextModel,
|
CLIPTextModel,
|
||||||
CLIPTokenizer,
|
CLIPTokenizer,
|
||||||
AutoFeatureExtractor,
|
AutoFeatureExtractor,
|
||||||
@ -45,7 +43,6 @@ from invokeai.app.services.config import (
|
|||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
||||||
from invokeai.frontend.install.widgets import (
|
from invokeai.frontend.install.widgets import (
|
||||||
SingleSelectColumns,
|
|
||||||
CenteredButtonPress,
|
CenteredButtonPress,
|
||||||
IntTitleSlider,
|
IntTitleSlider,
|
||||||
set_min_terminal_size,
|
set_min_terminal_size,
|
||||||
@ -226,64 +223,30 @@ def download_conversion_models():
|
|||||||
|
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
def download_realesrgan():
|
def download_realesrgan():
|
||||||
logger.info("Installing models from RealESRGAN...")
|
logger.info("Installing RealESRGAN models...")
|
||||||
model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
|
URLs = [
|
||||||
wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth"
|
dict(
|
||||||
|
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||||
model_dest = config.root_path / "models/core/upscaling/realesrgan/realesr-general-x4v3.pth"
|
dest = "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
|
||||||
wdn_model_dest = config.root_path / "models/core/upscaling/realesrgan/realesr-general-wdn-x4v3.pth"
|
description = "RealESRGAN_x4plus.pth",
|
||||||
|
),
|
||||||
download_with_progress_bar(model_url, str(model_dest), "RealESRGAN")
|
dict(
|
||||||
download_with_progress_bar(wdn_model_url, str(wdn_model_dest), "RealESRGANwdn")
|
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||||
|
dest = "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
|
||||||
|
description = "RealESRGAN_x4plus_anime_6B.pth",
|
||||||
def download_gfpgan():
|
),
|
||||||
logger.info("Installing GFPGAN models...")
|
dict(
|
||||||
for model in (
|
url= "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||||
[
|
dest= "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||||
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
|
description = "ESRGAN_SRx4_DF2KOST_official.pth",
|
||||||
"./models/core/face_restoration/gfpgan/GFPGANv1.4.pth",
|
),
|
||||||
],
|
]
|
||||||
[
|
for model in URLs:
|
||||||
"https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth",
|
download_with_progress_bar(model['url'], config.models_path / model['dest'], model['description'])
|
||||||
"./models/core/face_restoration/gfpgan/weights/detection_Resnet50_Final.pth",
|
|
||||||
],
|
|
||||||
[
|
|
||||||
"https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth",
|
|
||||||
"./models/core/face_restoration/gfpgan/weights/parsing_parsenet.pth",
|
|
||||||
],
|
|
||||||
):
|
|
||||||
model_url, model_dest = model[0], config.root_path / model[1]
|
|
||||||
download_with_progress_bar(model_url, str(model_dest), "GFPGAN weights")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
def download_codeformer():
|
|
||||||
logger.info("Installing CodeFormer model file...")
|
|
||||||
model_url = (
|
|
||||||
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
|
||||||
)
|
|
||||||
model_dest = config.root_path / "models/core/face_restoration/codeformer/codeformer.pth"
|
|
||||||
download_with_progress_bar(model_url, str(model_dest), "CodeFormer")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def download_clipseg():
|
|
||||||
logger.info("Installing clipseg model for text-based masking...")
|
|
||||||
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
|
||||||
try:
|
|
||||||
hf_download_from_pretrained(AutoProcessor, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
|
|
||||||
hf_download_from_pretrained(CLIPSegForImageSegmentation, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
|
|
||||||
except Exception:
|
|
||||||
logger.info("Error installing clipseg model:")
|
|
||||||
logger.info(traceback.format_exc())
|
|
||||||
|
|
||||||
|
|
||||||
def download_support_models():
|
def download_support_models():
|
||||||
download_realesrgan()
|
download_realesrgan()
|
||||||
download_gfpgan()
|
|
||||||
download_codeformer()
|
|
||||||
download_clipseg()
|
|
||||||
download_conversion_models()
|
download_conversion_models()
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
@ -858,9 +821,9 @@ def main():
|
|||||||
download_support_models()
|
download_support_models()
|
||||||
|
|
||||||
if opt.skip_sd_weights:
|
if opt.skip_sd_weights:
|
||||||
logger.info("\n** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **")
|
logger.warning("SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST")
|
||||||
elif models_to_download:
|
elif models_to_download:
|
||||||
logger.info("\n** DOWNLOADING DIFFUSION WEIGHTS **")
|
logger.info("DOWNLOADING DIFFUSION WEIGHTS")
|
||||||
process_and_execute(opt, models_to_download)
|
process_and_execute(opt, models_to_download)
|
||||||
|
|
||||||
postscript(errors=errors)
|
postscript(errors=errors)
|
||||||
|
@ -117,6 +117,7 @@ class ModelInstall(object):
|
|||||||
|
|
||||||
# supplement with entries in models.yaml
|
# supplement with entries in models.yaml
|
||||||
installed_models = self.mgr.list_models()
|
installed_models = self.mgr.list_models()
|
||||||
|
|
||||||
for md in installed_models:
|
for md in installed_models:
|
||||||
base = md['base_model']
|
base = md['base_model']
|
||||||
model_type = md['model_type']
|
model_type = md['model_type']
|
||||||
@ -134,6 +135,12 @@ class ModelInstall(object):
|
|||||||
)
|
)
|
||||||
return {x : model_dict[x] for x in sorted(model_dict.keys(),key=lambda y: model_dict[y].name.lower())}
|
return {x : model_dict[x] for x in sorted(model_dict.keys(),key=lambda y: model_dict[y].name.lower())}
|
||||||
|
|
||||||
|
def list_models(self, model_type):
|
||||||
|
installed = self.mgr.list_models(model_type=model_type)
|
||||||
|
print(f'Installed models of type `{model_type}`:')
|
||||||
|
for i in installed:
|
||||||
|
print(f"{i['model_name']}\t{i['base_model']}\t{i['path']}")
|
||||||
|
|
||||||
def starter_models(self)->Set[str]:
|
def starter_models(self)->Set[str]:
|
||||||
models = set()
|
models = set()
|
||||||
for key, value in self.datasets.items():
|
for key, value in self.datasets.items():
|
||||||
|
@ -3,6 +3,6 @@ Initialization file for invokeai.backend.model_management
|
|||||||
"""
|
"""
|
||||||
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
|
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
|
||||||
from .model_cache import ModelCache
|
from .model_cache import ModelCache
|
||||||
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType, ModelNotFoundException
|
||||||
from .model_merge import ModelMerger, MergeInterpolationMethod
|
from .model_merge import ModelMerger, MergeInterpolationMethod
|
||||||
|
|
||||||
|
@ -552,7 +552,7 @@ class ModelManager(object):
|
|||||||
model_config = self.models.get(model_key)
|
model_config = self.models.get(model_key)
|
||||||
if not model_config:
|
if not model_config:
|
||||||
self.logger.error(f'Unknown model {model_name}')
|
self.logger.error(f'Unknown model {model_name}')
|
||||||
raise KeyError(f'Unknown model {model_name}')
|
raise ModelNotFoundException(f'Unknown model {model_name}')
|
||||||
|
|
||||||
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
||||||
if base_model is not None and cur_base_model != base_model:
|
if base_model is not None and cur_base_model != base_model:
|
||||||
@ -596,7 +596,7 @@ class ModelManager(object):
|
|||||||
model_cfg = self.models.pop(model_key, None)
|
model_cfg = self.models.pop(model_key, None)
|
||||||
|
|
||||||
if model_cfg is None:
|
if model_cfg is None:
|
||||||
raise KeyError(f"Unknown model {model_key}")
|
raise ModelNotFoundException(f"Unknown model {model_key}")
|
||||||
|
|
||||||
# note: it not garantie to release memory(model can has other references)
|
# note: it not garantie to release memory(model can has other references)
|
||||||
cache_ids = self.cache_keys.pop(model_key, [])
|
cache_ids = self.cache_keys.pop(model_key, [])
|
||||||
@ -689,7 +689,7 @@ class ModelManager(object):
|
|||||||
model_key = self.create_key(model_name, base_model, model_type)
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
model_cfg = self.models.get(model_key, None)
|
model_cfg = self.models.get(model_key, None)
|
||||||
if not model_cfg:
|
if not model_cfg:
|
||||||
raise KeyError(f"Unknown model: {model_key}")
|
raise ModelNotFoundException(f"Unknown model: {model_key}")
|
||||||
|
|
||||||
old_path = self.app_config.root_path / model_cfg.path
|
old_path = self.app_config.root_path / model_cfg.path
|
||||||
new_name = new_name or model_name
|
new_name = new_name or model_name
|
||||||
@ -908,7 +908,6 @@ class ModelManager(object):
|
|||||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||||
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
|
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
|
||||||
|
|
||||||
|
|
||||||
class ScanAndImport(ModelSearch):
|
class ScanAndImport(ModelSearch):
|
||||||
def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall):
|
def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall):
|
||||||
super().__init__(directories, logger)
|
super().__init__(directories, logger)
|
||||||
@ -965,7 +964,7 @@ class ModelManager(object):
|
|||||||
that model.
|
that model.
|
||||||
|
|
||||||
May return the following exceptions:
|
May return the following exceptions:
|
||||||
- KeyError - one or more of the items to import is not a valid path, repo_id or URL
|
- ModelNotFoundException - one or more of the items to import is not a valid path, repo_id or URL
|
||||||
- ValueError - a corresponding model already exists
|
- ValueError - a corresponding model already exists
|
||||||
'''
|
'''
|
||||||
# avoid circular import here
|
# avoid circular import here
|
||||||
|
@ -1,4 +0,0 @@
|
|||||||
"""
|
|
||||||
Initialization file for the invokeai.backend.restoration package
|
|
||||||
"""
|
|
||||||
from .base import Restoration
|
|
@ -1,45 +0,0 @@
|
|||||||
import invokeai.backend.util.logging as logger
|
|
||||||
|
|
||||||
class Restoration:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def load_face_restore_models(
|
|
||||||
self, gfpgan_model_path="./models/core/face_restoration/gfpgan/GFPGANv1.4.pth"
|
|
||||||
):
|
|
||||||
# Load GFPGAN
|
|
||||||
gfpgan = self.load_gfpgan(gfpgan_model_path)
|
|
||||||
if gfpgan.gfpgan_model_exists:
|
|
||||||
logger.info("GFPGAN Initialized")
|
|
||||||
else:
|
|
||||||
logger.info("GFPGAN Disabled")
|
|
||||||
gfpgan = None
|
|
||||||
|
|
||||||
# Load CodeFormer
|
|
||||||
codeformer = self.load_codeformer()
|
|
||||||
if codeformer.codeformer_model_exists:
|
|
||||||
logger.info("CodeFormer Initialized")
|
|
||||||
else:
|
|
||||||
logger.info("CodeFormer Disabled")
|
|
||||||
codeformer = None
|
|
||||||
|
|
||||||
return gfpgan, codeformer
|
|
||||||
|
|
||||||
# Face Restore Models
|
|
||||||
def load_gfpgan(self, gfpgan_model_path):
|
|
||||||
from .gfpgan import GFPGAN
|
|
||||||
|
|
||||||
return GFPGAN(gfpgan_model_path)
|
|
||||||
|
|
||||||
def load_codeformer(self):
|
|
||||||
from .codeformer import CodeFormerRestoration
|
|
||||||
|
|
||||||
return CodeFormerRestoration()
|
|
||||||
|
|
||||||
# Upscale Models
|
|
||||||
def load_esrgan(self, esrgan_bg_tile=400):
|
|
||||||
from .realesrgan import ESRGAN
|
|
||||||
|
|
||||||
esrgan = ESRGAN(esrgan_bg_tile)
|
|
||||||
logger.info("ESRGAN Initialized")
|
|
||||||
return esrgan
|
|
@ -1,120 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
|
|
||||||
pretrained_model_url = (
|
|
||||||
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CodeFormerRestoration:
|
|
||||||
def __init__(
|
|
||||||
self, codeformer_dir="./models/core/face_restoration/codeformer", codeformer_model_path="codeformer.pth"
|
|
||||||
) -> None:
|
|
||||||
|
|
||||||
self.globals = InvokeAIAppConfig.get_config()
|
|
||||||
codeformer_dir = self.globals.root_dir / codeformer_dir
|
|
||||||
self.model_path = codeformer_dir / codeformer_model_path
|
|
||||||
self.codeformer_model_exists = self.model_path.exists()
|
|
||||||
|
|
||||||
if not self.codeformer_model_exists:
|
|
||||||
logger.error(f"NOT FOUND: CodeFormer model not found at {self.model_path}")
|
|
||||||
sys.path.append(os.path.abspath(codeformer_dir))
|
|
||||||
|
|
||||||
def process(self, image, strength, device, seed=None, fidelity=0.75):
|
|
||||||
if seed is not None:
|
|
||||||
logger.info(f"CodeFormer - Restoring Faces for image seed:{seed}")
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
|
|
||||||
from basicsr.utils import img2tensor, tensor2img
|
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
|
||||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
|
||||||
from PIL import Image
|
|
||||||
from torchvision.transforms.functional import normalize
|
|
||||||
|
|
||||||
from .codeformer_arch import CodeFormer
|
|
||||||
|
|
||||||
cf_class = CodeFormer
|
|
||||||
|
|
||||||
cf = cf_class(
|
|
||||||
dim_embd=512,
|
|
||||||
codebook_size=1024,
|
|
||||||
n_head=8,
|
|
||||||
n_layers=9,
|
|
||||||
connect_list=["32", "64", "128", "256"],
|
|
||||||
).to(device)
|
|
||||||
|
|
||||||
# note that this file should already be downloaded and cached at
|
|
||||||
# this point
|
|
||||||
checkpoint_path = load_file_from_url(
|
|
||||||
url=pretrained_model_url,
|
|
||||||
model_dir=os.path.abspath(os.path.dirname(self.model_path)),
|
|
||||||
progress=True,
|
|
||||||
)
|
|
||||||
checkpoint = torch.load(checkpoint_path)["params_ema"]
|
|
||||||
cf.load_state_dict(checkpoint)
|
|
||||||
cf.eval()
|
|
||||||
|
|
||||||
image = image.convert("RGB")
|
|
||||||
# Codeformer expects a BGR np array; make array and flip channels
|
|
||||||
bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]
|
|
||||||
|
|
||||||
face_helper = FaceRestoreHelper(
|
|
||||||
upscale_factor=1,
|
|
||||||
use_parse=True,
|
|
||||||
device=device,
|
|
||||||
model_rootpath = self.globals.model_path / 'core/face_restoration/gfpgan/weights'
|
|
||||||
)
|
|
||||||
face_helper.clean_all()
|
|
||||||
face_helper.read_image(bgr_image_array)
|
|
||||||
face_helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5)
|
|
||||||
face_helper.align_warp_face()
|
|
||||||
|
|
||||||
for idx, cropped_face in enumerate(face_helper.cropped_faces):
|
|
||||||
cropped_face_t = img2tensor(
|
|
||||||
cropped_face / 255.0, bgr2rgb=True, float32=True
|
|
||||||
)
|
|
||||||
normalize(
|
|
||||||
cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True
|
|
||||||
)
|
|
||||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with torch.no_grad():
|
|
||||||
output = cf(cropped_face_t, w=fidelity, adain=True)[0]
|
|
||||||
restored_face = tensor2img(
|
|
||||||
output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)
|
|
||||||
)
|
|
||||||
del output
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
except RuntimeError as error:
|
|
||||||
logger.error(f"Failed inference for CodeFormer: {error}.")
|
|
||||||
restored_face = cropped_face
|
|
||||||
|
|
||||||
restored_face = restored_face.astype("uint8")
|
|
||||||
face_helper.add_restored_face(restored_face)
|
|
||||||
|
|
||||||
face_helper.get_inverse_affine(None)
|
|
||||||
|
|
||||||
restored_img = face_helper.paste_faces_to_input_image()
|
|
||||||
|
|
||||||
# Flip the channels back to RGB
|
|
||||||
res = Image.fromarray(restored_img[..., ::-1])
|
|
||||||
|
|
||||||
if strength < 1.0:
|
|
||||||
# Resize the image to the new image if the sizes have changed
|
|
||||||
if restored_img.size != image.size:
|
|
||||||
image = image.resize(res.size)
|
|
||||||
res = Image.blend(image, res, strength)
|
|
||||||
|
|
||||||
cf = None
|
|
||||||
|
|
||||||
return res
|
|
@ -1,325 +0,0 @@
|
|||||||
import math
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from basicsr.utils import get_root_logger
|
|
||||||
from basicsr.utils.registry import ARCH_REGISTRY
|
|
||||||
from torch import Tensor, nn
|
|
||||||
|
|
||||||
from .vqgan_arch import *
|
|
||||||
|
|
||||||
|
|
||||||
def calc_mean_std(feat, eps=1e-5):
|
|
||||||
"""Calculate mean and std for adaptive_instance_normalization.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
feat (Tensor): 4D tensor.
|
|
||||||
eps (float): A small value added to the variance to avoid
|
|
||||||
divide-by-zero. Default: 1e-5.
|
|
||||||
"""
|
|
||||||
size = feat.size()
|
|
||||||
assert len(size) == 4, "The input feature should be 4D tensor."
|
|
||||||
b, c = size[:2]
|
|
||||||
feat_var = feat.view(b, c, -1).var(dim=2) + eps
|
|
||||||
feat_std = feat_var.sqrt().view(b, c, 1, 1)
|
|
||||||
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
|
|
||||||
return feat_mean, feat_std
|
|
||||||
|
|
||||||
|
|
||||||
def adaptive_instance_normalization(content_feat, style_feat):
|
|
||||||
"""Adaptive instance normalization.
|
|
||||||
|
|
||||||
Adjust the reference features to have the similar color and illuminations
|
|
||||||
as those in the degradate features.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content_feat (Tensor): The reference feature.
|
|
||||||
style_feat (Tensor): The degradate features.
|
|
||||||
"""
|
|
||||||
size = content_feat.size()
|
|
||||||
style_mean, style_std = calc_mean_std(style_feat)
|
|
||||||
content_mean, content_std = calc_mean_std(content_feat)
|
|
||||||
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(
|
|
||||||
size
|
|
||||||
)
|
|
||||||
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
|
||||||
|
|
||||||
|
|
||||||
class PositionEmbeddingSine(nn.Module):
|
|
||||||
"""
|
|
||||||
This is a more standard version of the position embedding, very similar to the one
|
|
||||||
used by the Attention is all you need paper, generalized to work on images.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.num_pos_feats = num_pos_feats
|
|
||||||
self.temperature = temperature
|
|
||||||
self.normalize = normalize
|
|
||||||
if scale is not None and normalize is False:
|
|
||||||
raise ValueError("normalize should be True if scale is passed")
|
|
||||||
if scale is None:
|
|
||||||
scale = 2 * math.pi
|
|
||||||
self.scale = scale
|
|
||||||
|
|
||||||
def forward(self, x, mask=None):
|
|
||||||
if mask is None:
|
|
||||||
mask = torch.zeros(
|
|
||||||
(x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
|
|
||||||
)
|
|
||||||
not_mask = ~mask
|
|
||||||
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
|
||||||
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
|
||||||
if self.normalize:
|
|
||||||
eps = 1e-6
|
|
||||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
|
||||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
|
||||||
|
|
||||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
|
||||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
|
||||||
|
|
||||||
pos_x = x_embed[:, :, :, None] / dim_t
|
|
||||||
pos_y = y_embed[:, :, :, None] / dim_t
|
|
||||||
pos_x = torch.stack(
|
|
||||||
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
|
||||||
).flatten(3)
|
|
||||||
pos_y = torch.stack(
|
|
||||||
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
|
||||||
).flatten(3)
|
|
||||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
|
||||||
return pos
|
|
||||||
|
|
||||||
|
|
||||||
def _get_activation_fn(activation):
|
|
||||||
"""Return an activation function given a string"""
|
|
||||||
if activation == "relu":
|
|
||||||
return F.relu
|
|
||||||
if activation == "gelu":
|
|
||||||
return F.gelu
|
|
||||||
if activation == "glu":
|
|
||||||
return F.glu
|
|
||||||
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerSALayer(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
|
|
||||||
# Implementation of Feedforward model - MLP
|
|
||||||
self.linear1 = nn.Linear(embed_dim, dim_mlp)
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
|
||||||
self.linear2 = nn.Linear(dim_mlp, embed_dim)
|
|
||||||
|
|
||||||
self.norm1 = nn.LayerNorm(embed_dim)
|
|
||||||
self.norm2 = nn.LayerNorm(embed_dim)
|
|
||||||
self.dropout1 = nn.Dropout(dropout)
|
|
||||||
self.dropout2 = nn.Dropout(dropout)
|
|
||||||
|
|
||||||
self.activation = _get_activation_fn(activation)
|
|
||||||
|
|
||||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
|
||||||
return tensor if pos is None else tensor + pos
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
tgt,
|
|
||||||
tgt_mask: Optional[Tensor] = None,
|
|
||||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
query_pos: Optional[Tensor] = None,
|
|
||||||
):
|
|
||||||
# self attention
|
|
||||||
tgt2 = self.norm1(tgt)
|
|
||||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
|
||||||
tgt2 = self.self_attn(
|
|
||||||
q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
|
|
||||||
)[0]
|
|
||||||
tgt = tgt + self.dropout1(tgt2)
|
|
||||||
|
|
||||||
# ffn
|
|
||||||
tgt2 = self.norm2(tgt)
|
|
||||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
|
||||||
tgt = tgt + self.dropout2(tgt2)
|
|
||||||
return tgt
|
|
||||||
|
|
||||||
|
|
||||||
class Fuse_sft_block(nn.Module):
|
|
||||||
def __init__(self, in_ch, out_ch):
|
|
||||||
super().__init__()
|
|
||||||
self.encode_enc = ResBlock(2 * in_ch, out_ch)
|
|
||||||
|
|
||||||
self.scale = nn.Sequential(
|
|
||||||
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
|
||||||
nn.LeakyReLU(0.2, True),
|
|
||||||
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.shift = nn.Sequential(
|
|
||||||
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
|
||||||
nn.LeakyReLU(0.2, True),
|
|
||||||
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, enc_feat, dec_feat, w=1):
|
|
||||||
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
|
|
||||||
scale = self.scale(enc_feat)
|
|
||||||
shift = self.shift(enc_feat)
|
|
||||||
residual = w * (dec_feat * scale + shift)
|
|
||||||
out = dec_feat + residual
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
@ARCH_REGISTRY.register()
|
|
||||||
class CodeFormer(VQAutoEncoder):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim_embd=512,
|
|
||||||
n_head=8,
|
|
||||||
n_layers=9,
|
|
||||||
codebook_size=1024,
|
|
||||||
latent_size=256,
|
|
||||||
connect_list=["32", "64", "128", "256"],
|
|
||||||
fix_modules=["quantize", "generator"],
|
|
||||||
):
|
|
||||||
super(CodeFormer, self).__init__(
|
|
||||||
512, 64, [1, 2, 2, 4, 4, 8], "nearest", 2, [16], codebook_size
|
|
||||||
)
|
|
||||||
|
|
||||||
if fix_modules is not None:
|
|
||||||
for module in fix_modules:
|
|
||||||
for param in getattr(self, module).parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
self.connect_list = connect_list
|
|
||||||
self.n_layers = n_layers
|
|
||||||
self.dim_embd = dim_embd
|
|
||||||
self.dim_mlp = dim_embd * 2
|
|
||||||
|
|
||||||
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
|
|
||||||
self.feat_emb = nn.Linear(256, self.dim_embd)
|
|
||||||
|
|
||||||
# transformer
|
|
||||||
self.ft_layers = nn.Sequential(
|
|
||||||
*[
|
|
||||||
TransformerSALayer(
|
|
||||||
embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0
|
|
||||||
)
|
|
||||||
for _ in range(self.n_layers)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# logits_predict head
|
|
||||||
self.idx_pred_layer = nn.Sequential(
|
|
||||||
nn.LayerNorm(dim_embd), nn.Linear(dim_embd, codebook_size, bias=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.channels = {
|
|
||||||
"16": 512,
|
|
||||||
"32": 256,
|
|
||||||
"64": 256,
|
|
||||||
"128": 128,
|
|
||||||
"256": 128,
|
|
||||||
"512": 64,
|
|
||||||
}
|
|
||||||
|
|
||||||
# after second residual block for > 16, before attn layer for ==16
|
|
||||||
self.fuse_encoder_block = {
|
|
||||||
"512": 2,
|
|
||||||
"256": 5,
|
|
||||||
"128": 8,
|
|
||||||
"64": 11,
|
|
||||||
"32": 14,
|
|
||||||
"16": 18,
|
|
||||||
}
|
|
||||||
# after first residual block for > 16, before attn layer for ==16
|
|
||||||
self.fuse_generator_block = {
|
|
||||||
"16": 6,
|
|
||||||
"32": 9,
|
|
||||||
"64": 12,
|
|
||||||
"128": 15,
|
|
||||||
"256": 18,
|
|
||||||
"512": 21,
|
|
||||||
}
|
|
||||||
|
|
||||||
# fuse_convs_dict
|
|
||||||
self.fuse_convs_dict = nn.ModuleDict()
|
|
||||||
for f_size in self.connect_list:
|
|
||||||
in_ch = self.channels[f_size]
|
|
||||||
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
|
||||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
|
||||||
module.weight.data.normal_(mean=0.0, std=0.02)
|
|
||||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
|
||||||
module.bias.data.zero_()
|
|
||||||
elif isinstance(module, nn.LayerNorm):
|
|
||||||
module.bias.data.zero_()
|
|
||||||
module.weight.data.fill_(1.0)
|
|
||||||
|
|
||||||
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
|
|
||||||
# ################### Encoder #####################
|
|
||||||
enc_feat_dict = {}
|
|
||||||
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
|
|
||||||
for i, block in enumerate(self.encoder.blocks):
|
|
||||||
x = block(x)
|
|
||||||
if i in out_list:
|
|
||||||
enc_feat_dict[str(x.shape[-1])] = x.clone()
|
|
||||||
|
|
||||||
lq_feat = x
|
|
||||||
# ################# Transformer ###################
|
|
||||||
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
|
|
||||||
pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1)
|
|
||||||
# BCHW -> BC(HW) -> (HW)BC
|
|
||||||
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
|
|
||||||
query_emb = feat_emb
|
|
||||||
# Transformer encoder
|
|
||||||
for layer in self.ft_layers:
|
|
||||||
query_emb = layer(query_emb, query_pos=pos_emb)
|
|
||||||
|
|
||||||
# output logits
|
|
||||||
logits = self.idx_pred_layer(query_emb) # (hw)bn
|
|
||||||
logits = logits.permute(1, 0, 2) # (hw)bn -> b(hw)n
|
|
||||||
|
|
||||||
if code_only: # for training stage II
|
|
||||||
# logits doesn't need softmax before cross_entropy loss
|
|
||||||
return logits, lq_feat
|
|
||||||
|
|
||||||
# ################# Quantization ###################
|
|
||||||
# if self.training:
|
|
||||||
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
|
|
||||||
# # b(hw)c -> bc(hw) -> bchw
|
|
||||||
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
|
|
||||||
# ------------
|
|
||||||
soft_one_hot = F.softmax(logits, dim=2)
|
|
||||||
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
|
|
||||||
quant_feat = self.quantize.get_codebook_feat(
|
|
||||||
top_idx, shape=[x.shape[0], 16, 16, 256]
|
|
||||||
)
|
|
||||||
# preserve gradients
|
|
||||||
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
|
|
||||||
|
|
||||||
if detach_16:
|
|
||||||
quant_feat = quant_feat.detach() # for training stage III
|
|
||||||
if adain:
|
|
||||||
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
|
|
||||||
|
|
||||||
# ################## Generator ####################
|
|
||||||
x = quant_feat
|
|
||||||
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
|
||||||
|
|
||||||
for i, block in enumerate(self.generator.blocks):
|
|
||||||
x = block(x)
|
|
||||||
if i in fuse_list: # fuse after i-th block
|
|
||||||
f_size = str(x.shape[-1])
|
|
||||||
if w > 0:
|
|
||||||
x = self.fuse_convs_dict[f_size](
|
|
||||||
enc_feat_dict[f_size].detach(), x, w
|
|
||||||
)
|
|
||||||
out = x
|
|
||||||
# logits doesn't need softmax before cross_entropy loss
|
|
||||||
return out, logits, lq_feat
|
|
@ -1,84 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
|
|
||||||
class GFPGAN:
|
|
||||||
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
|
|
||||||
self.globals = InvokeAIAppConfig.get_config()
|
|
||||||
if not os.path.isabs(gfpgan_model_path):
|
|
||||||
gfpgan_model_path = self.globals.root_dir / gfpgan_model_path
|
|
||||||
self.model_path = gfpgan_model_path
|
|
||||||
self.gfpgan_model_exists = os.path.isfile(self.model_path)
|
|
||||||
|
|
||||||
if not self.gfpgan_model_exists:
|
|
||||||
logger.error(f"NOT FOUND: GFPGAN model not found at {self.model_path}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def model_exists(self):
|
|
||||||
return os.path.isfile(self.model_path)
|
|
||||||
|
|
||||||
def process(self, image, strength: float, seed: str = None):
|
|
||||||
if seed is not None:
|
|
||||||
logger.info(f"GFPGAN - Restoring Faces for image seed:{seed}")
|
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
cwd = os.getcwd()
|
|
||||||
os.chdir(self.globals.root_dir / 'models')
|
|
||||||
try:
|
|
||||||
from gfpgan import GFPGANer
|
|
||||||
|
|
||||||
self.gfpgan = GFPGANer(
|
|
||||||
model_path=self.model_path,
|
|
||||||
upscale=1,
|
|
||||||
arch="clean",
|
|
||||||
channel_multiplier=2,
|
|
||||||
bg_upsampler=None,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
logger.error("Error loading GFPGAN:", file=sys.stderr)
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
os.chdir(cwd)
|
|
||||||
|
|
||||||
if self.gfpgan is None:
|
|
||||||
logger.warning("WARNING: GFPGAN not initialized.")
|
|
||||||
logger.warning(
|
|
||||||
f"Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
|
|
||||||
)
|
|
||||||
|
|
||||||
image = image.convert("RGB")
|
|
||||||
|
|
||||||
# GFPGAN expects a BGR np array; make array and flip channels
|
|
||||||
bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]
|
|
||||||
|
|
||||||
_, _, restored_img = self.gfpgan.enhance(
|
|
||||||
bgr_image_array,
|
|
||||||
has_aligned=False,
|
|
||||||
only_center_face=False,
|
|
||||||
paste_back=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Flip the channels back to RGB
|
|
||||||
res = Image.fromarray(restored_img[..., ::-1])
|
|
||||||
|
|
||||||
if strength < 1.0:
|
|
||||||
# Resize the image to the new image if the sizes have changed
|
|
||||||
if restored_img.size != image.size:
|
|
||||||
image = image.resize(res.size)
|
|
||||||
res = Image.blend(image, res, strength)
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
self.gfpgan = None
|
|
||||||
|
|
||||||
return res
|
|
@ -1,118 +0,0 @@
|
|||||||
import math
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
|
|
||||||
class Outcrop(object):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
image,
|
|
||||||
generate, # current generate object
|
|
||||||
):
|
|
||||||
self.image = image
|
|
||||||
self.generate = generate
|
|
||||||
|
|
||||||
def process(
|
|
||||||
self,
|
|
||||||
extents: dict,
|
|
||||||
opt, # current options
|
|
||||||
orig_opt, # ones originally used to generate the image
|
|
||||||
image_callback=None,
|
|
||||||
prefix=None,
|
|
||||||
):
|
|
||||||
# grow and mask the image
|
|
||||||
extended_image = self._extend_all(extents)
|
|
||||||
|
|
||||||
# switch samplers temporarily
|
|
||||||
curr_sampler = self.generate.sampler
|
|
||||||
self.generate.sampler_name = opt.sampler_name
|
|
||||||
self.generate._set_scheduler()
|
|
||||||
|
|
||||||
def wrapped_callback(img, seed, **kwargs):
|
|
||||||
preferred_seed = (
|
|
||||||
orig_opt.seed
|
|
||||||
if orig_opt.seed is not None and orig_opt.seed >= 0
|
|
||||||
else seed
|
|
||||||
)
|
|
||||||
image_callback(img, preferred_seed, use_prefix=prefix, **kwargs)
|
|
||||||
|
|
||||||
result = self.generate.prompt2image(
|
|
||||||
opt.prompt,
|
|
||||||
seed=opt.seed or orig_opt.seed,
|
|
||||||
sampler=self.generate.sampler,
|
|
||||||
steps=opt.steps,
|
|
||||||
cfg_scale=opt.cfg_scale,
|
|
||||||
ddim_eta=self.generate.ddim_eta,
|
|
||||||
width=extended_image.width,
|
|
||||||
height=extended_image.height,
|
|
||||||
init_img=extended_image,
|
|
||||||
strength=0.90,
|
|
||||||
image_callback=wrapped_callback if image_callback else None,
|
|
||||||
seam_size=opt.seam_size or 96,
|
|
||||||
seam_blur=opt.seam_blur or 16,
|
|
||||||
seam_strength=opt.seam_strength or 0.7,
|
|
||||||
seam_steps=20,
|
|
||||||
tile_size=32,
|
|
||||||
color_match=True,
|
|
||||||
force_outpaint=True, # this just stops the warning about erased regions
|
|
||||||
)
|
|
||||||
|
|
||||||
# swap sampler back
|
|
||||||
self.generate.sampler = curr_sampler
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _extend_all(
|
|
||||||
self,
|
|
||||||
extents: dict,
|
|
||||||
) -> Image:
|
|
||||||
"""
|
|
||||||
Extend the image in direction ('top','bottom','left','right') by
|
|
||||||
the indicated value. The image canvas is extended, and the empty
|
|
||||||
rectangular section will be filled with a blurred copy of the
|
|
||||||
adjacent image.
|
|
||||||
"""
|
|
||||||
image = self.image
|
|
||||||
for direction in extents:
|
|
||||||
assert direction in [
|
|
||||||
"top",
|
|
||||||
"left",
|
|
||||||
"bottom",
|
|
||||||
"right",
|
|
||||||
], 'Direction must be one of "top", "left", "bottom", "right"'
|
|
||||||
pixels = extents[direction]
|
|
||||||
# round pixels up to the nearest 64
|
|
||||||
pixels = math.ceil(pixels / 64) * 64
|
|
||||||
logger.info(f"extending image {direction}ward by {pixels} pixels")
|
|
||||||
image = self._rotate(image, direction)
|
|
||||||
image = self._extend(image, pixels)
|
|
||||||
image = self._rotate(image, direction, reverse=True)
|
|
||||||
return image
|
|
||||||
|
|
||||||
def _rotate(self, image: Image, direction: str, reverse=False) -> Image:
|
|
||||||
"""
|
|
||||||
Rotates image so that the area to extend is always at the top top.
|
|
||||||
Simplifies logic later. The reverse argument, if true, will undo the
|
|
||||||
previous transpose.
|
|
||||||
"""
|
|
||||||
transposes = {
|
|
||||||
"right": ["ROTATE_90", "ROTATE_270"],
|
|
||||||
"bottom": ["ROTATE_180", "ROTATE_180"],
|
|
||||||
"left": ["ROTATE_270", "ROTATE_90"],
|
|
||||||
}
|
|
||||||
if direction not in transposes:
|
|
||||||
return image
|
|
||||||
transpose = transposes[direction][1 if reverse else 0]
|
|
||||||
return image.transpose(Image.Transpose.__dict__[transpose])
|
|
||||||
|
|
||||||
def _extend(self, image: Image, pixels: int) -> Image:
|
|
||||||
extended_img = Image.new("RGBA", (image.width, image.height + pixels))
|
|
||||||
|
|
||||||
extended_img.paste((0, 0, 0), [0, 0, image.width, image.height + pixels])
|
|
||||||
extended_img.paste(image, box=(0, pixels))
|
|
||||||
|
|
||||||
# now make the top part transparent to use as a mask
|
|
||||||
alpha = extended_img.getchannel("A")
|
|
||||||
alpha.paste(0, (0, 0, extended_img.width, pixels))
|
|
||||||
extended_img.putalpha(alpha)
|
|
||||||
|
|
||||||
return extended_img
|
|
@ -1,102 +0,0 @@
|
|||||||
import math
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
from PIL import Image, ImageFilter
|
|
||||||
|
|
||||||
|
|
||||||
class Outpaint(object):
|
|
||||||
def __init__(self, image, generate):
|
|
||||||
self.image = image
|
|
||||||
self.generate = generate
|
|
||||||
|
|
||||||
def process(self, opt, old_opt, image_callback=None, prefix=None):
|
|
||||||
image = self._create_outpaint_image(self.image, opt.out_direction)
|
|
||||||
|
|
||||||
seed = old_opt.seed
|
|
||||||
prompt = old_opt.prompt
|
|
||||||
|
|
||||||
def wrapped_callback(img, seed, **kwargs):
|
|
||||||
image_callback(img, seed, use_prefix=prefix, **kwargs)
|
|
||||||
|
|
||||||
return self.generate.prompt2image(
|
|
||||||
prompt,
|
|
||||||
seed=seed,
|
|
||||||
sampler=self.generate.sampler,
|
|
||||||
steps=opt.steps,
|
|
||||||
cfg_scale=opt.cfg_scale,
|
|
||||||
ddim_eta=self.generate.ddim_eta,
|
|
||||||
width=opt.width,
|
|
||||||
height=opt.height,
|
|
||||||
init_img=image,
|
|
||||||
strength=0.83,
|
|
||||||
image_callback=wrapped_callback,
|
|
||||||
prefix=prefix,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_outpaint_image(self, image, direction_args):
|
|
||||||
assert len(direction_args) in [
|
|
||||||
1,
|
|
||||||
2,
|
|
||||||
], "Direction (-D) must have exactly one or two arguments."
|
|
||||||
|
|
||||||
if len(direction_args) == 1:
|
|
||||||
direction = direction_args[0]
|
|
||||||
pixels = None
|
|
||||||
elif len(direction_args) == 2:
|
|
||||||
direction = direction_args[0]
|
|
||||||
pixels = int(direction_args[1])
|
|
||||||
|
|
||||||
assert direction in [
|
|
||||||
"top",
|
|
||||||
"left",
|
|
||||||
"bottom",
|
|
||||||
"right",
|
|
||||||
], 'Direction (-D) must be one of "top", "left", "bottom", "right"'
|
|
||||||
|
|
||||||
image = image.convert("RGBA")
|
|
||||||
# we always extend top, but rotate to extend along the requested side
|
|
||||||
if direction == "left":
|
|
||||||
image = image.transpose(Image.Transpose.ROTATE_270)
|
|
||||||
elif direction == "bottom":
|
|
||||||
image = image.transpose(Image.Transpose.ROTATE_180)
|
|
||||||
elif direction == "right":
|
|
||||||
image = image.transpose(Image.Transpose.ROTATE_90)
|
|
||||||
|
|
||||||
pixels = image.height // 2 if pixels is None else int(pixels)
|
|
||||||
assert (
|
|
||||||
0 < pixels < image.height
|
|
||||||
), "Direction (-D) pixels length must be in the range 0 - image.size"
|
|
||||||
|
|
||||||
# the top part of the image is taken from the source image mirrored
|
|
||||||
# coordinates (0,0) are the upper left corner of an image
|
|
||||||
top = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).convert("RGBA")
|
|
||||||
top = top.crop((0, top.height - pixels, top.width, top.height))
|
|
||||||
|
|
||||||
# setting all alpha of the top part to 0
|
|
||||||
alpha = top.getchannel("A")
|
|
||||||
alpha.paste(0, (0, 0, top.width, top.height))
|
|
||||||
top.putalpha(alpha)
|
|
||||||
|
|
||||||
# taking the bottom from the original image
|
|
||||||
bottom = image.crop((0, 0, image.width, image.height - pixels))
|
|
||||||
|
|
||||||
new_img = image.copy()
|
|
||||||
new_img.paste(top, (0, 0))
|
|
||||||
new_img.paste(bottom, (0, pixels))
|
|
||||||
|
|
||||||
# create a 10% dither in the middle
|
|
||||||
dither = min(image.height // 10, pixels)
|
|
||||||
for x in range(0, image.width, 2):
|
|
||||||
for y in range(pixels - dither, pixels + dither):
|
|
||||||
(r, g, b, a) = new_img.getpixel((x, y))
|
|
||||||
new_img.putpixel((x, y), (r, g, b, 0))
|
|
||||||
|
|
||||||
# let's rotate back again
|
|
||||||
if direction == "left":
|
|
||||||
new_img = new_img.transpose(Image.Transpose.ROTATE_90)
|
|
||||||
elif direction == "bottom":
|
|
||||||
new_img = new_img.transpose(Image.Transpose.ROTATE_180)
|
|
||||||
elif direction == "right":
|
|
||||||
new_img = new_img.transpose(Image.Transpose.ROTATE_270)
|
|
||||||
|
|
||||||
return new_img
|
|
@ -1,104 +0,0 @@
|
|||||||
import warnings
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
from PIL.Image import Image as ImageType
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
|
||||||
|
|
||||||
class ESRGAN:
|
|
||||||
def __init__(self, bg_tile_size=400) -> None:
|
|
||||||
self.bg_tile_size = bg_tile_size
|
|
||||||
|
|
||||||
def load_esrgan_bg_upsampler(self, denoise_str):
|
|
||||||
if not torch.cuda.is_available(): # CPU or MPS on M1
|
|
||||||
use_half_precision = False
|
|
||||||
else:
|
|
||||||
use_half_precision = True
|
|
||||||
|
|
||||||
from realesrgan import RealESRGANer
|
|
||||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
|
||||||
|
|
||||||
model = SRVGGNetCompact(
|
|
||||||
num_in_ch=3,
|
|
||||||
num_out_ch=3,
|
|
||||||
num_feat=64,
|
|
||||||
num_conv=32,
|
|
||||||
upscale=4,
|
|
||||||
act_type="prelu",
|
|
||||||
)
|
|
||||||
model_path = config.models_path / "core/upscaling/realesrgan/realesr-general-x4v3.pth"
|
|
||||||
wdn_model_path = config.models_path / "core/upscaling/realesrgan/realesr-general-wdn-x4v3.pth"
|
|
||||||
scale = 4
|
|
||||||
|
|
||||||
bg_upsampler = RealESRGANer(
|
|
||||||
scale=scale,
|
|
||||||
model_path=[model_path, wdn_model_path],
|
|
||||||
model=model,
|
|
||||||
tile=self.bg_tile_size,
|
|
||||||
dni_weight=[denoise_str, 1 - denoise_str],
|
|
||||||
tile_pad=10,
|
|
||||||
pre_pad=0,
|
|
||||||
half=use_half_precision,
|
|
||||||
)
|
|
||||||
|
|
||||||
return bg_upsampler
|
|
||||||
|
|
||||||
def process(
|
|
||||||
self,
|
|
||||||
image: ImageType,
|
|
||||||
strength: float,
|
|
||||||
seed: str = None,
|
|
||||||
upsampler_scale: int = 2,
|
|
||||||
denoise_str: float = 0.75,
|
|
||||||
):
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
|
|
||||||
try:
|
|
||||||
upsampler = self.load_esrgan_bg_upsampler(denoise_str)
|
|
||||||
except Exception:
|
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
logger.error("Error loading Real-ESRGAN:")
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
if upsampler_scale == 0:
|
|
||||||
logger.warning("Real-ESRGAN: Invalid scaling option. Image not upscaled.")
|
|
||||||
return image
|
|
||||||
|
|
||||||
if seed is not None:
|
|
||||||
logger.info(
|
|
||||||
f"Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
|
|
||||||
)
|
|
||||||
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
|
|
||||||
image = image.convert("RGB")
|
|
||||||
|
|
||||||
# REALSRGAN expects a BGR np array; make array and flip channels
|
|
||||||
bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]
|
|
||||||
|
|
||||||
output, _ = upsampler.enhance(
|
|
||||||
bgr_image_array,
|
|
||||||
outscale=upsampler_scale,
|
|
||||||
alpha_upsampler="realesrgan",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Flip the channels back to RGB
|
|
||||||
res = Image.fromarray(output[..., ::-1])
|
|
||||||
|
|
||||||
if strength < 1.0:
|
|
||||||
# Resize the image to the new image if the sizes have changed
|
|
||||||
if output.size != image.size:
|
|
||||||
image = image.resize(res.size)
|
|
||||||
res = Image.blend(image, res, strength)
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
upsampler = None
|
|
||||||
|
|
||||||
return res
|
|
@ -1,514 +0,0 @@
|
|||||||
"""
|
|
||||||
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
|
|
||||||
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
|
|
||||||
|
|
||||||
"""
|
|
||||||
import copy
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from basicsr.utils import get_root_logger
|
|
||||||
from basicsr.utils.registry import ARCH_REGISTRY
|
|
||||||
|
|
||||||
|
|
||||||
def normalize(in_channels):
|
|
||||||
return torch.nn.GroupNorm(
|
|
||||||
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
def swish(x):
|
|
||||||
return x * torch.sigmoid(x)
|
|
||||||
|
|
||||||
|
|
||||||
# Define VQVAE classes
|
|
||||||
class VectorQuantizer(nn.Module):
|
|
||||||
def __init__(self, codebook_size, emb_dim, beta):
|
|
||||||
super(VectorQuantizer, self).__init__()
|
|
||||||
self.codebook_size = codebook_size # number of embeddings
|
|
||||||
self.emb_dim = emb_dim # dimension of embedding
|
|
||||||
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
|
|
||||||
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
|
|
||||||
self.embedding.weight.data.uniform_(
|
|
||||||
-1.0 / self.codebook_size, 1.0 / self.codebook_size
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, z):
|
|
||||||
# reshape z -> (batch, height, width, channel) and flatten
|
|
||||||
z = z.permute(0, 2, 3, 1).contiguous()
|
|
||||||
z_flattened = z.view(-1, self.emb_dim)
|
|
||||||
|
|
||||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
|
||||||
d = (
|
|
||||||
(z_flattened**2).sum(dim=1, keepdim=True)
|
|
||||||
+ (self.embedding.weight**2).sum(1)
|
|
||||||
- 2 * torch.matmul(z_flattened, self.embedding.weight.t())
|
|
||||||
)
|
|
||||||
|
|
||||||
mean_distance = torch.mean(d)
|
|
||||||
# find closest encodings
|
|
||||||
# min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
|
|
||||||
min_encoding_scores, min_encoding_indices = torch.topk(
|
|
||||||
d, 1, dim=1, largest=False
|
|
||||||
)
|
|
||||||
# [0-1], higher score, higher confidence
|
|
||||||
min_encoding_scores = torch.exp(-min_encoding_scores / 10)
|
|
||||||
|
|
||||||
min_encodings = torch.zeros(
|
|
||||||
min_encoding_indices.shape[0], self.codebook_size
|
|
||||||
).to(z)
|
|
||||||
min_encodings.scatter_(1, min_encoding_indices, 1)
|
|
||||||
|
|
||||||
# get quantized latent vectors
|
|
||||||
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
|
|
||||||
# compute loss for embedding
|
|
||||||
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
|
|
||||||
(z_q - z.detach()) ** 2
|
|
||||||
)
|
|
||||||
# preserve gradients
|
|
||||||
z_q = z + (z_q - z).detach()
|
|
||||||
|
|
||||||
# perplexity
|
|
||||||
e_mean = torch.mean(min_encodings, dim=0)
|
|
||||||
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
|
|
||||||
# reshape back to match original input shape
|
|
||||||
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
|
||||||
|
|
||||||
return (
|
|
||||||
z_q,
|
|
||||||
loss,
|
|
||||||
{
|
|
||||||
"perplexity": perplexity,
|
|
||||||
"min_encodings": min_encodings,
|
|
||||||
"min_encoding_indices": min_encoding_indices,
|
|
||||||
"min_encoding_scores": min_encoding_scores,
|
|
||||||
"mean_distance": mean_distance,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_codebook_feat(self, indices, shape):
|
|
||||||
# input indices: batch*token_num -> (batch*token_num)*1
|
|
||||||
# shape: batch, height, width, channel
|
|
||||||
indices = indices.view(-1, 1)
|
|
||||||
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
|
|
||||||
min_encodings.scatter_(1, indices, 1)
|
|
||||||
# get quantized latent vectors
|
|
||||||
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
|
|
||||||
|
|
||||||
if shape is not None: # reshape back to match original input shape
|
|
||||||
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
|
|
||||||
|
|
||||||
return z_q
|
|
||||||
|
|
||||||
|
|
||||||
class GumbelQuantizer(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
codebook_size,
|
|
||||||
emb_dim,
|
|
||||||
num_hiddens,
|
|
||||||
straight_through=False,
|
|
||||||
kl_weight=5e-4,
|
|
||||||
temp_init=1.0,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.codebook_size = codebook_size # number of embeddings
|
|
||||||
self.emb_dim = emb_dim # dimension of embedding
|
|
||||||
self.straight_through = straight_through
|
|
||||||
self.temperature = temp_init
|
|
||||||
self.kl_weight = kl_weight
|
|
||||||
self.proj = nn.Conv2d(
|
|
||||||
num_hiddens, codebook_size, 1
|
|
||||||
) # projects last encoder layer to quantized logits
|
|
||||||
self.embed = nn.Embedding(codebook_size, emb_dim)
|
|
||||||
|
|
||||||
def forward(self, z):
|
|
||||||
hard = self.straight_through if self.training else True
|
|
||||||
|
|
||||||
logits = self.proj(z)
|
|
||||||
|
|
||||||
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
|
|
||||||
|
|
||||||
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
|
|
||||||
|
|
||||||
# + kl divergence to the prior loss
|
|
||||||
qy = F.softmax(logits, dim=1)
|
|
||||||
diff = (
|
|
||||||
self.kl_weight
|
|
||||||
* torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
|
|
||||||
)
|
|
||||||
min_encoding_indices = soft_one_hot.argmax(dim=1)
|
|
||||||
|
|
||||||
return z_q, diff, {"min_encoding_indices": min_encoding_indices}
|
|
||||||
|
|
||||||
|
|
||||||
class Downsample(nn.Module):
|
|
||||||
def __init__(self, in_channels):
|
|
||||||
super().__init__()
|
|
||||||
self.conv = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
pad = (0, 1, 0, 1)
|
|
||||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
|
||||||
x = self.conv(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Upsample(nn.Module):
|
|
||||||
def __init__(self, in_channels):
|
|
||||||
super().__init__()
|
|
||||||
self.conv = nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
|
||||||
x = self.conv(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class ResBlock(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels=None):
|
|
||||||
super(ResBlock, self).__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = in_channels if out_channels is None else out_channels
|
|
||||||
self.norm1 = normalize(in_channels)
|
|
||||||
self.conv1 = nn.Conv2d(
|
|
||||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
self.norm2 = normalize(out_channels)
|
|
||||||
self.conv2 = nn.Conv2d(
|
|
||||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
if self.in_channels != self.out_channels:
|
|
||||||
self.conv_out = nn.Conv2d(
|
|
||||||
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x_in):
|
|
||||||
x = x_in
|
|
||||||
x = self.norm1(x)
|
|
||||||
x = swish(x)
|
|
||||||
x = self.conv1(x)
|
|
||||||
x = self.norm2(x)
|
|
||||||
x = swish(x)
|
|
||||||
x = self.conv2(x)
|
|
||||||
if self.in_channels != self.out_channels:
|
|
||||||
x_in = self.conv_out(x_in)
|
|
||||||
|
|
||||||
return x + x_in
|
|
||||||
|
|
||||||
|
|
||||||
class AttnBlock(nn.Module):
|
|
||||||
def __init__(self, in_channels):
|
|
||||||
super().__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
|
|
||||||
self.norm = normalize(in_channels)
|
|
||||||
self.q = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
|
||||||
)
|
|
||||||
self.k = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
|
||||||
)
|
|
||||||
self.v = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
|
||||||
)
|
|
||||||
self.proj_out = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
h_ = x
|
|
||||||
h_ = self.norm(h_)
|
|
||||||
q = self.q(h_)
|
|
||||||
k = self.k(h_)
|
|
||||||
v = self.v(h_)
|
|
||||||
|
|
||||||
# compute attention
|
|
||||||
b, c, h, w = q.shape
|
|
||||||
q = q.reshape(b, c, h * w)
|
|
||||||
q = q.permute(0, 2, 1)
|
|
||||||
k = k.reshape(b, c, h * w)
|
|
||||||
w_ = torch.bmm(q, k)
|
|
||||||
w_ = w_ * (int(c) ** (-0.5))
|
|
||||||
w_ = F.softmax(w_, dim=2)
|
|
||||||
|
|
||||||
# attend to values
|
|
||||||
v = v.reshape(b, c, h * w)
|
|
||||||
w_ = w_.permute(0, 2, 1)
|
|
||||||
h_ = torch.bmm(v, w_)
|
|
||||||
h_ = h_.reshape(b, c, h, w)
|
|
||||||
|
|
||||||
h_ = self.proj_out(h_)
|
|
||||||
|
|
||||||
return x + h_
|
|
||||||
|
|
||||||
|
|
||||||
class Encoder(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
nf,
|
|
||||||
emb_dim,
|
|
||||||
ch_mult,
|
|
||||||
num_res_blocks,
|
|
||||||
resolution,
|
|
||||||
attn_resolutions,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.nf = nf
|
|
||||||
self.num_resolutions = len(ch_mult)
|
|
||||||
self.num_res_blocks = num_res_blocks
|
|
||||||
self.resolution = resolution
|
|
||||||
self.attn_resolutions = attn_resolutions
|
|
||||||
|
|
||||||
curr_res = self.resolution
|
|
||||||
in_ch_mult = (1,) + tuple(ch_mult)
|
|
||||||
|
|
||||||
blocks = []
|
|
||||||
# initial convultion
|
|
||||||
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
|
|
||||||
|
|
||||||
# residual and downsampling blocks, with attention on smaller res (16x16)
|
|
||||||
for i in range(self.num_resolutions):
|
|
||||||
block_in_ch = nf * in_ch_mult[i]
|
|
||||||
block_out_ch = nf * ch_mult[i]
|
|
||||||
for _ in range(self.num_res_blocks):
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
|
||||||
block_in_ch = block_out_ch
|
|
||||||
if curr_res in attn_resolutions:
|
|
||||||
blocks.append(AttnBlock(block_in_ch))
|
|
||||||
|
|
||||||
if i != self.num_resolutions - 1:
|
|
||||||
blocks.append(Downsample(block_in_ch))
|
|
||||||
curr_res = curr_res // 2
|
|
||||||
|
|
||||||
# non-local attention block
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
|
||||||
blocks.append(AttnBlock(block_in_ch))
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
|
||||||
|
|
||||||
# normalise and convert to latent size
|
|
||||||
blocks.append(normalize(block_in_ch))
|
|
||||||
blocks.append(
|
|
||||||
nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)
|
|
||||||
)
|
|
||||||
self.blocks = nn.ModuleList(blocks)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
for block in self.blocks:
|
|
||||||
x = block(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Generator(nn.Module):
|
|
||||||
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
|
|
||||||
super().__init__()
|
|
||||||
self.nf = nf
|
|
||||||
self.ch_mult = ch_mult
|
|
||||||
self.num_resolutions = len(self.ch_mult)
|
|
||||||
self.num_res_blocks = res_blocks
|
|
||||||
self.resolution = img_size
|
|
||||||
self.attn_resolutions = attn_resolutions
|
|
||||||
self.in_channels = emb_dim
|
|
||||||
self.out_channels = 3
|
|
||||||
block_in_ch = self.nf * self.ch_mult[-1]
|
|
||||||
curr_res = self.resolution // 2 ** (self.num_resolutions - 1)
|
|
||||||
|
|
||||||
blocks = []
|
|
||||||
# initial conv
|
|
||||||
blocks.append(
|
|
||||||
nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)
|
|
||||||
)
|
|
||||||
|
|
||||||
# non-local attention block
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
|
||||||
blocks.append(AttnBlock(block_in_ch))
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
|
||||||
|
|
||||||
for i in reversed(range(self.num_resolutions)):
|
|
||||||
block_out_ch = self.nf * self.ch_mult[i]
|
|
||||||
|
|
||||||
for _ in range(self.num_res_blocks):
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
|
||||||
block_in_ch = block_out_ch
|
|
||||||
|
|
||||||
if curr_res in self.attn_resolutions:
|
|
||||||
blocks.append(AttnBlock(block_in_ch))
|
|
||||||
|
|
||||||
if i != 0:
|
|
||||||
blocks.append(Upsample(block_in_ch))
|
|
||||||
curr_res = curr_res * 2
|
|
||||||
|
|
||||||
blocks.append(normalize(block_in_ch))
|
|
||||||
blocks.append(
|
|
||||||
nn.Conv2d(
|
|
||||||
block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.blocks = nn.ModuleList(blocks)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
for block in self.blocks:
|
|
||||||
x = block(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
@ARCH_REGISTRY.register()
|
|
||||||
class VQAutoEncoder(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
img_size,
|
|
||||||
nf,
|
|
||||||
ch_mult,
|
|
||||||
quantizer="nearest",
|
|
||||||
res_blocks=2,
|
|
||||||
attn_resolutions=[16],
|
|
||||||
codebook_size=1024,
|
|
||||||
emb_dim=256,
|
|
||||||
beta=0.25,
|
|
||||||
gumbel_straight_through=False,
|
|
||||||
gumbel_kl_weight=1e-8,
|
|
||||||
model_path=None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
logger = get_root_logger()
|
|
||||||
self.in_channels = 3
|
|
||||||
self.nf = nf
|
|
||||||
self.n_blocks = res_blocks
|
|
||||||
self.codebook_size = codebook_size
|
|
||||||
self.embed_dim = emb_dim
|
|
||||||
self.ch_mult = ch_mult
|
|
||||||
self.resolution = img_size
|
|
||||||
self.attn_resolutions = attn_resolutions
|
|
||||||
self.quantizer_type = quantizer
|
|
||||||
self.encoder = Encoder(
|
|
||||||
self.in_channels,
|
|
||||||
self.nf,
|
|
||||||
self.embed_dim,
|
|
||||||
self.ch_mult,
|
|
||||||
self.n_blocks,
|
|
||||||
self.resolution,
|
|
||||||
self.attn_resolutions,
|
|
||||||
)
|
|
||||||
if self.quantizer_type == "nearest":
|
|
||||||
self.beta = beta # 0.25
|
|
||||||
self.quantize = VectorQuantizer(
|
|
||||||
self.codebook_size, self.embed_dim, self.beta
|
|
||||||
)
|
|
||||||
elif self.quantizer_type == "gumbel":
|
|
||||||
self.gumbel_num_hiddens = emb_dim
|
|
||||||
self.straight_through = gumbel_straight_through
|
|
||||||
self.kl_weight = gumbel_kl_weight
|
|
||||||
self.quantize = GumbelQuantizer(
|
|
||||||
self.codebook_size,
|
|
||||||
self.embed_dim,
|
|
||||||
self.gumbel_num_hiddens,
|
|
||||||
self.straight_through,
|
|
||||||
self.kl_weight,
|
|
||||||
)
|
|
||||||
self.generator = Generator(
|
|
||||||
self.nf,
|
|
||||||
self.embed_dim,
|
|
||||||
self.ch_mult,
|
|
||||||
self.n_blocks,
|
|
||||||
self.resolution,
|
|
||||||
self.attn_resolutions,
|
|
||||||
)
|
|
||||||
|
|
||||||
if model_path is not None:
|
|
||||||
chkpt = torch.load(model_path, map_location="cpu")
|
|
||||||
if "params_ema" in chkpt:
|
|
||||||
self.load_state_dict(
|
|
||||||
torch.load(model_path, map_location="cpu")["params_ema"]
|
|
||||||
)
|
|
||||||
logger.info(f"vqgan is loaded from: {model_path} [params_ema]")
|
|
||||||
elif "params" in chkpt:
|
|
||||||
self.load_state_dict(
|
|
||||||
torch.load(model_path, map_location="cpu")["params"]
|
|
||||||
)
|
|
||||||
logger.info(f"vqgan is loaded from: {model_path} [params]")
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Wrong params!")
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.encoder(x)
|
|
||||||
quant, codebook_loss, quant_stats = self.quantize(x)
|
|
||||||
x = self.generator(quant)
|
|
||||||
return x, codebook_loss, quant_stats
|
|
||||||
|
|
||||||
|
|
||||||
# patch based discriminator
|
|
||||||
@ARCH_REGISTRY.register()
|
|
||||||
class VQGANDiscriminator(nn.Module):
|
|
||||||
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
layers = [
|
|
||||||
nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),
|
|
||||||
nn.LeakyReLU(0.2, True),
|
|
||||||
]
|
|
||||||
ndf_mult = 1
|
|
||||||
ndf_mult_prev = 1
|
|
||||||
for n in range(1, n_layers): # gradually increase the number of filters
|
|
||||||
ndf_mult_prev = ndf_mult
|
|
||||||
ndf_mult = min(2**n, 8)
|
|
||||||
layers += [
|
|
||||||
nn.Conv2d(
|
|
||||||
ndf * ndf_mult_prev,
|
|
||||||
ndf * ndf_mult,
|
|
||||||
kernel_size=4,
|
|
||||||
stride=2,
|
|
||||||
padding=1,
|
|
||||||
bias=False,
|
|
||||||
),
|
|
||||||
nn.BatchNorm2d(ndf * ndf_mult),
|
|
||||||
nn.LeakyReLU(0.2, True),
|
|
||||||
]
|
|
||||||
|
|
||||||
ndf_mult_prev = ndf_mult
|
|
||||||
ndf_mult = min(2**n_layers, 8)
|
|
||||||
|
|
||||||
layers += [
|
|
||||||
nn.Conv2d(
|
|
||||||
ndf * ndf_mult_prev,
|
|
||||||
ndf * ndf_mult,
|
|
||||||
kernel_size=4,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
bias=False,
|
|
||||||
),
|
|
||||||
nn.BatchNorm2d(ndf * ndf_mult),
|
|
||||||
nn.LeakyReLU(0.2, True),
|
|
||||||
]
|
|
||||||
|
|
||||||
layers += [
|
|
||||||
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)
|
|
||||||
] # output 1 channel prediction map
|
|
||||||
self.main = nn.Sequential(*layers)
|
|
||||||
|
|
||||||
if model_path is not None:
|
|
||||||
chkpt = torch.load(model_path, map_location="cpu")
|
|
||||||
if "params_d" in chkpt:
|
|
||||||
self.load_state_dict(
|
|
||||||
torch.load(model_path, map_location="cpu")["params_d"]
|
|
||||||
)
|
|
||||||
elif "params" in chkpt:
|
|
||||||
self.load_state_dict(
|
|
||||||
torch.load(model_path, map_location="cpu")["params"]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Wrong params!")
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.main(x)
|
|
@ -221,7 +221,7 @@ class ControlNetData:
|
|||||||
control_mode: str = Field(default="balanced")
|
control_mode: str = Field(default="balanced")
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass
|
||||||
class ConditioningData:
|
class ConditioningData:
|
||||||
unconditioned_embeddings: torch.Tensor
|
unconditioned_embeddings: torch.Tensor
|
||||||
text_embeddings: torch.Tensor
|
text_embeddings: torch.Tensor
|
||||||
@ -507,6 +507,40 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
control_data: List[ControlNetData] = None,
|
control_data: List[ControlNetData] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
def _pad_conditioning(cond, target_len, encoder_attention_mask):
|
||||||
|
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
|
||||||
|
|
||||||
|
if cond.shape[1] < max_len:
|
||||||
|
conditioning_attention_mask = torch.cat([
|
||||||
|
conditioning_attention_mask,
|
||||||
|
torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
|
||||||
|
], dim=1)
|
||||||
|
|
||||||
|
cond = torch.cat([
|
||||||
|
cond,
|
||||||
|
torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
|
||||||
|
], dim=1)
|
||||||
|
|
||||||
|
if encoder_attention_mask is None:
|
||||||
|
encoder_attention_mask = conditioning_attention_mask
|
||||||
|
else:
|
||||||
|
encoder_attention_mask = torch.cat([
|
||||||
|
encoder_attention_mask,
|
||||||
|
conditioning_attention_mask,
|
||||||
|
])
|
||||||
|
|
||||||
|
return cond, encoder_attention_mask
|
||||||
|
|
||||||
|
encoder_attention_mask = None
|
||||||
|
if conditioning_data.unconditioned_embeddings.shape[1] != conditioning_data.text_embeddings.shape[1]:
|
||||||
|
max_len = max(conditioning_data.unconditioned_embeddings.shape[1], conditioning_data.text_embeddings.shape[1])
|
||||||
|
conditioning_data.unconditioned_embeddings, encoder_attention_mask = _pad_conditioning(
|
||||||
|
conditioning_data.unconditioned_embeddings, max_len, encoder_attention_mask
|
||||||
|
)
|
||||||
|
conditioning_data.text_embeddings, encoder_attention_mask = _pad_conditioning(
|
||||||
|
conditioning_data.text_embeddings, max_len, encoder_attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
self._adjust_memory_efficient_attention(latents)
|
self._adjust_memory_efficient_attention(latents)
|
||||||
if run_id is None:
|
if run_id is None:
|
||||||
run_id = secrets.token_urlsafe(self.ID_LENGTH)
|
run_id = secrets.token_urlsafe(self.ID_LENGTH)
|
||||||
@ -546,6 +580,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
total_step_count=len(timesteps),
|
total_step_count=len(timesteps),
|
||||||
additional_guidance=additional_guidance,
|
additional_guidance=additional_guidance,
|
||||||
control_data=control_data,
|
control_data=control_data,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
latents = step_output.prev_sample
|
latents = step_output.prev_sample
|
||||||
@ -603,6 +638,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
down_block_res_samples, mid_block_res_sample = None, None
|
down_block_res_samples, mid_block_res_sample = None, None
|
||||||
|
|
||||||
if control_data is not None:
|
if control_data is not None:
|
||||||
|
# TODO: rewrite to pass with conditionings
|
||||||
|
encoder_attention_mask = kwargs.get("encoder_attention_mask", None)
|
||||||
# control_data should be type List[ControlNetData]
|
# control_data should be type List[ControlNetData]
|
||||||
# this loop covers both ControlNet (one ControlNetData in list)
|
# this loop covers both ControlNet (one ControlNetData in list)
|
||||||
# and MultiControlNet (multiple ControlNetData in list)
|
# and MultiControlNet (multiple ControlNetData in list)
|
||||||
@ -649,6 +686,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
controlnet_cond=control_datum.image_tensor,
|
controlnet_cond=control_datum.image_tensor,
|
||||||
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
|
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
|
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
)
|
)
|
||||||
|
@ -241,45 +241,12 @@ class InvokeAIDiffuserComponent:
|
|||||||
|
|
||||||
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||||
# fast batched path
|
# fast batched path
|
||||||
|
|
||||||
def _pad_conditioning(cond, target_len, encoder_attention_mask):
|
|
||||||
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
|
|
||||||
|
|
||||||
if cond.shape[1] < max_len:
|
|
||||||
conditioning_attention_mask = torch.cat([
|
|
||||||
conditioning_attention_mask,
|
|
||||||
torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
|
|
||||||
], dim=1)
|
|
||||||
|
|
||||||
cond = torch.cat([
|
|
||||||
cond,
|
|
||||||
torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
|
|
||||||
], dim=1)
|
|
||||||
|
|
||||||
if encoder_attention_mask is None:
|
|
||||||
encoder_attention_mask = conditioning_attention_mask
|
|
||||||
else:
|
|
||||||
encoder_attention_mask = torch.cat([
|
|
||||||
encoder_attention_mask,
|
|
||||||
conditioning_attention_mask,
|
|
||||||
])
|
|
||||||
|
|
||||||
return cond, encoder_attention_mask
|
|
||||||
|
|
||||||
x_twice = torch.cat([x] * 2)
|
x_twice = torch.cat([x] * 2)
|
||||||
sigma_twice = torch.cat([sigma] * 2)
|
sigma_twice = torch.cat([sigma] * 2)
|
||||||
|
|
||||||
encoder_attention_mask = None
|
|
||||||
if unconditioning.shape[1] != conditioning.shape[1]:
|
|
||||||
max_len = max(unconditioning.shape[1], conditioning.shape[1])
|
|
||||||
unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
|
|
||||||
conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
|
|
||||||
|
|
||||||
both_conditionings = torch.cat([unconditioning, conditioning])
|
both_conditionings = torch.cat([unconditioning, conditioning])
|
||||||
both_results = self.model_forward_callback(
|
both_results = self.model_forward_callback(
|
||||||
x_twice, sigma_twice, both_conditionings,
|
x_twice, sigma_twice, both_conditionings, **kwargs,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
@ -24,7 +24,7 @@ import torch.utils.checkpoint
|
|||||||
import transformers
|
import transformers
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed, ProjectConfiguration
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
DDPMScheduler,
|
DDPMScheduler,
|
||||||
@ -35,7 +35,6 @@ from diffusers.optimization import get_scheduler
|
|||||||
from diffusers.utils import check_min_version
|
from diffusers.utils import check_min_version
|
||||||
from diffusers.utils.import_utils import is_xformers_available
|
from diffusers.utils.import_utils import is_xformers_available
|
||||||
from huggingface_hub import HfFolder, Repository, whoami
|
from huggingface_hub import HfFolder, Repository, whoami
|
||||||
from omegaconf import OmegaConf
|
|
||||||
|
|
||||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||||
from packaging import version
|
from packaging import version
|
||||||
@ -47,6 +46,8 @@ from transformers import CLIPTextModel, CLIPTokenizer
|
|||||||
|
|
||||||
# invokeai stuff
|
# invokeai stuff
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig,PagingArgumentParser
|
from invokeai.app.services.config import InvokeAIAppConfig,PagingArgumentParser
|
||||||
|
from invokeai.app.services.model_manager_service import ModelManagerService
|
||||||
|
from invokeai.backend.model_management.models import SubModelType
|
||||||
|
|
||||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||||
PIL_INTERPOLATION = {
|
PIL_INTERPOLATION = {
|
||||||
@ -132,7 +133,7 @@ def parse_args():
|
|||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--model",
|
"--model",
|
||||||
type=str,
|
type=str,
|
||||||
default="stable-diffusion-1.5",
|
default="sd-1/main/stable-diffusion-v1-5",
|
||||||
help="Name of the diffusers model to train against, as defined in configs/models.yaml.",
|
help="Name of the diffusers model to train against, as defined in configs/models.yaml.",
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
@ -565,7 +566,6 @@ def do_textual_inversion_training(
|
|||||||
checkpointing_steps: int = 500,
|
checkpointing_steps: int = 500,
|
||||||
resume_from_checkpoint: Path = None,
|
resume_from_checkpoint: Path = None,
|
||||||
enable_xformers_memory_efficient_attention: bool = False,
|
enable_xformers_memory_efficient_attention: bool = False,
|
||||||
root_dir: Path = None,
|
|
||||||
hub_model_id: str = None,
|
hub_model_id: str = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@ -584,13 +584,17 @@ def do_textual_inversion_training(
|
|||||||
|
|
||||||
logging_dir = output_dir / logging_dir
|
logging_dir = output_dir / logging_dir
|
||||||
|
|
||||||
|
accelerator_config = ProjectConfiguration()
|
||||||
|
accelerator_config.logging_dir = logging_dir
|
||||||
accelerator = Accelerator(
|
accelerator = Accelerator(
|
||||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||||
mixed_precision=mixed_precision,
|
mixed_precision=mixed_precision,
|
||||||
log_with=report_to,
|
log_with=report_to,
|
||||||
logging_dir=logging_dir,
|
project_config=accelerator_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
model_manager = ModelManagerService(config,logger)
|
||||||
|
|
||||||
# Make one log on every process with the configuration for debugging.
|
# Make one log on every process with the configuration for debugging.
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
@ -628,46 +632,46 @@ def do_textual_inversion_training(
|
|||||||
elif output_dir is not None:
|
elif output_dir is not None:
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
models_conf = OmegaConf.load(config.model_conf_path)
|
known_models = model_manager.model_names()
|
||||||
model_conf = models_conf.get(model, None)
|
model_name = model.split('/')[-1]
|
||||||
assert model_conf is not None, f"Unknown model: {model}"
|
model_meta = next((mm for mm in known_models if mm[0].endswith(model_name)), None)
|
||||||
|
assert model_meta is not None, f"Unknown model: {model}"
|
||||||
|
model_info = model_manager.model_info(*model_meta)
|
||||||
assert (
|
assert (
|
||||||
model_conf.get("format", "diffusers") == "diffusers"
|
model_info['model_format'] == "diffusers"
|
||||||
), "This script only works with models of type 'diffusers'"
|
), "This script only works with models of type 'diffusers'"
|
||||||
pretrained_model_name_or_path = model_conf.get("repo_id", None) or Path(
|
tokenizer_info = model_manager.get_model(*model_meta, submodel=SubModelType.Tokenizer)
|
||||||
model_conf.get("path")
|
noise_scheduler_info = model_manager.get_model(*model_meta, submodel=SubModelType.Scheduler)
|
||||||
)
|
text_encoder_info = model_manager.get_model(*model_meta, submodel=SubModelType.TextEncoder)
|
||||||
assert (
|
vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae)
|
||||||
pretrained_model_name_or_path
|
unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet)
|
||||||
), f"models.yaml error: neither 'repo_id' nor 'path' is defined for {model}"
|
|
||||||
pipeline_args = dict(cache_dir=config.cache_dir)
|
|
||||||
|
|
||||||
# Load tokenizer
|
pipeline_args = dict(local_files_only=True)
|
||||||
if tokenizer_name:
|
if tokenizer_name:
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
|
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
|
||||||
else:
|
else:
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(
|
tokenizer = CLIPTokenizer.from_pretrained(
|
||||||
pretrained_model_name_or_path, subfolder="tokenizer", **pipeline_args
|
tokenizer_info.location, subfolder='tokenizer', **pipeline_args
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load scheduler and models
|
# Load scheduler and models
|
||||||
noise_scheduler = DDPMScheduler.from_pretrained(
|
noise_scheduler = DDPMScheduler.from_pretrained(
|
||||||
pretrained_model_name_or_path, subfolder="scheduler", **pipeline_args
|
noise_scheduler_info.location, subfolder="scheduler", **pipeline_args
|
||||||
)
|
)
|
||||||
text_encoder = CLIPTextModel.from_pretrained(
|
text_encoder = CLIPTextModel.from_pretrained(
|
||||||
pretrained_model_name_or_path,
|
text_encoder_info.location,
|
||||||
subfolder="text_encoder",
|
subfolder="text_encoder",
|
||||||
revision=revision,
|
revision=revision,
|
||||||
**pipeline_args,
|
**pipeline_args,
|
||||||
)
|
)
|
||||||
vae = AutoencoderKL.from_pretrained(
|
vae = AutoencoderKL.from_pretrained(
|
||||||
pretrained_model_name_or_path,
|
vae_info.location,
|
||||||
subfolder="vae",
|
subfolder="vae",
|
||||||
revision=revision,
|
revision=revision,
|
||||||
**pipeline_args,
|
**pipeline_args,
|
||||||
)
|
)
|
||||||
unet = UNet2DConditionModel.from_pretrained(
|
unet = UNet2DConditionModel.from_pretrained(
|
||||||
pretrained_model_name_or_path,
|
unet_info.location,
|
||||||
subfolder="unet",
|
subfolder="unet",
|
||||||
revision=revision,
|
revision=revision,
|
||||||
**pipeline_args,
|
**pipeline_args,
|
||||||
@ -989,7 +993,7 @@ def do_textual_inversion_training(
|
|||||||
save_full_model = not only_save_embeds
|
save_full_model = not only_save_embeds
|
||||||
if save_full_model:
|
if save_full_model:
|
||||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||||
pretrained_model_name_or_path,
|
unet_info.location,
|
||||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||||
vae=vae,
|
vae=vae,
|
||||||
unet=unet,
|
unet=unet,
|
||||||
|
634
invokeai/backend/util/hotfixes.py
Normal file
634
invokeai/backend/util/hotfixes.py
Normal file
@ -0,0 +1,634 @@
|
|||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||||
|
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
|
||||||
|
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
||||||
|
from diffusers.models.modeling_utils import ModelMixin
|
||||||
|
from diffusers.models.unet_2d_blocks import (
|
||||||
|
CrossAttnDownBlock2D,
|
||||||
|
DownBlock2D,
|
||||||
|
UNetMidBlock2DCrossAttn,
|
||||||
|
get_down_block,
|
||||||
|
)
|
||||||
|
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||||
|
|
||||||
|
import diffusers
|
||||||
|
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
|
||||||
|
|
||||||
|
# Modified ControlNetModel with encoder_attention_mask argument added
|
||||||
|
|
||||||
|
class ControlNetModel(ModelMixin, ConfigMixin):
|
||||||
|
"""
|
||||||
|
A ControlNet model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (`int`, defaults to 4):
|
||||||
|
The number of channels in the input sample.
|
||||||
|
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||||
|
Whether to flip the sin to cos in the time embedding.
|
||||||
|
freq_shift (`int`, defaults to 0):
|
||||||
|
The frequency shift to apply to the time embedding.
|
||||||
|
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||||
|
The tuple of downsample blocks to use.
|
||||||
|
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||||
|
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||||
|
The tuple of output channels for each block.
|
||||||
|
layers_per_block (`int`, defaults to 2):
|
||||||
|
The number of layers per block.
|
||||||
|
downsample_padding (`int`, defaults to 1):
|
||||||
|
The padding to use for the downsampling convolution.
|
||||||
|
mid_block_scale_factor (`float`, defaults to 1):
|
||||||
|
The scale factor to use for the mid block.
|
||||||
|
act_fn (`str`, defaults to "silu"):
|
||||||
|
The activation function to use.
|
||||||
|
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||||
|
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
||||||
|
in post-processing.
|
||||||
|
norm_eps (`float`, defaults to 1e-5):
|
||||||
|
The epsilon to use for the normalization.
|
||||||
|
cross_attention_dim (`int`, defaults to 1280):
|
||||||
|
The dimension of the cross attention features.
|
||||||
|
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
||||||
|
The dimension of the attention heads.
|
||||||
|
use_linear_projection (`bool`, defaults to `False`):
|
||||||
|
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||||
|
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
||||||
|
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
||||||
|
num_class_embeds (`int`, *optional*, defaults to 0):
|
||||||
|
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||||
|
class conditioning with `class_embed_type` equal to `None`.
|
||||||
|
upcast_attention (`bool`, defaults to `False`):
|
||||||
|
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
||||||
|
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
||||||
|
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
||||||
|
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
||||||
|
`class_embed_type="projection"`.
|
||||||
|
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
||||||
|
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
||||||
|
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
||||||
|
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
||||||
|
global_pool_conditions (`bool`, defaults to `False`):
|
||||||
|
"""
|
||||||
|
|
||||||
|
_supports_gradient_checkpointing = True
|
||||||
|
|
||||||
|
@register_to_config
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int = 4,
|
||||||
|
conditioning_channels: int = 3,
|
||||||
|
flip_sin_to_cos: bool = True,
|
||||||
|
freq_shift: int = 0,
|
||||||
|
down_block_types: Tuple[str] = (
|
||||||
|
"CrossAttnDownBlock2D",
|
||||||
|
"CrossAttnDownBlock2D",
|
||||||
|
"CrossAttnDownBlock2D",
|
||||||
|
"DownBlock2D",
|
||||||
|
),
|
||||||
|
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||||
|
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||||
|
layers_per_block: int = 2,
|
||||||
|
downsample_padding: int = 1,
|
||||||
|
mid_block_scale_factor: float = 1,
|
||||||
|
act_fn: str = "silu",
|
||||||
|
norm_num_groups: Optional[int] = 32,
|
||||||
|
norm_eps: float = 1e-5,
|
||||||
|
cross_attention_dim: int = 1280,
|
||||||
|
attention_head_dim: Union[int, Tuple[int]] = 8,
|
||||||
|
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
||||||
|
use_linear_projection: bool = False,
|
||||||
|
class_embed_type: Optional[str] = None,
|
||||||
|
num_class_embeds: Optional[int] = None,
|
||||||
|
upcast_attention: bool = False,
|
||||||
|
resnet_time_scale_shift: str = "default",
|
||||||
|
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||||
|
controlnet_conditioning_channel_order: str = "rgb",
|
||||||
|
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
||||||
|
global_pool_conditions: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# If `num_attention_heads` is not defined (which is the case for most models)
|
||||||
|
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
||||||
|
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
||||||
|
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
||||||
|
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
||||||
|
# which is why we correct for the naming here.
|
||||||
|
num_attention_heads = num_attention_heads or attention_head_dim
|
||||||
|
|
||||||
|
# Check inputs
|
||||||
|
if len(block_out_channels) != len(down_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
# input
|
||||||
|
conv_in_kernel = 3
|
||||||
|
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||||
|
self.conv_in = nn.Conv2d(
|
||||||
|
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
||||||
|
)
|
||||||
|
|
||||||
|
# time
|
||||||
|
time_embed_dim = block_out_channels[0] * 4
|
||||||
|
|
||||||
|
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||||
|
timestep_input_dim = block_out_channels[0]
|
||||||
|
|
||||||
|
self.time_embedding = TimestepEmbedding(
|
||||||
|
timestep_input_dim,
|
||||||
|
time_embed_dim,
|
||||||
|
act_fn=act_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
# class embedding
|
||||||
|
if class_embed_type is None and num_class_embeds is not None:
|
||||||
|
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||||
|
elif class_embed_type == "timestep":
|
||||||
|
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||||
|
elif class_embed_type == "identity":
|
||||||
|
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||||
|
elif class_embed_type == "projection":
|
||||||
|
if projection_class_embeddings_input_dim is None:
|
||||||
|
raise ValueError(
|
||||||
|
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
||||||
|
)
|
||||||
|
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
||||||
|
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
||||||
|
# 2. it projects from an arbitrary input dimension.
|
||||||
|
#
|
||||||
|
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
||||||
|
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
||||||
|
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
||||||
|
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||||
|
else:
|
||||||
|
self.class_embedding = None
|
||||||
|
|
||||||
|
# control net conditioning embedding
|
||||||
|
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
|
||||||
|
conditioning_embedding_channels=block_out_channels[0],
|
||||||
|
block_out_channels=conditioning_embedding_out_channels,
|
||||||
|
conditioning_channels=conditioning_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.down_blocks = nn.ModuleList([])
|
||||||
|
self.controlnet_down_blocks = nn.ModuleList([])
|
||||||
|
|
||||||
|
if isinstance(only_cross_attention, bool):
|
||||||
|
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
||||||
|
|
||||||
|
if isinstance(attention_head_dim, int):
|
||||||
|
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||||
|
|
||||||
|
if isinstance(num_attention_heads, int):
|
||||||
|
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
||||||
|
|
||||||
|
# down
|
||||||
|
output_channel = block_out_channels[0]
|
||||||
|
|
||||||
|
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||||
|
controlnet_block = zero_module(controlnet_block)
|
||||||
|
self.controlnet_down_blocks.append(controlnet_block)
|
||||||
|
|
||||||
|
for i, down_block_type in enumerate(down_block_types):
|
||||||
|
input_channel = output_channel
|
||||||
|
output_channel = block_out_channels[i]
|
||||||
|
is_final_block = i == len(block_out_channels) - 1
|
||||||
|
|
||||||
|
down_block = get_down_block(
|
||||||
|
down_block_type,
|
||||||
|
num_layers=layers_per_block,
|
||||||
|
in_channels=input_channel,
|
||||||
|
out_channels=output_channel,
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
add_downsample=not is_final_block,
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
num_attention_heads=num_attention_heads[i],
|
||||||
|
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
||||||
|
downsample_padding=downsample_padding,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
|
only_cross_attention=only_cross_attention[i],
|
||||||
|
upcast_attention=upcast_attention,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
)
|
||||||
|
self.down_blocks.append(down_block)
|
||||||
|
|
||||||
|
for _ in range(layers_per_block):
|
||||||
|
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||||
|
controlnet_block = zero_module(controlnet_block)
|
||||||
|
self.controlnet_down_blocks.append(controlnet_block)
|
||||||
|
|
||||||
|
if not is_final_block:
|
||||||
|
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||||
|
controlnet_block = zero_module(controlnet_block)
|
||||||
|
self.controlnet_down_blocks.append(controlnet_block)
|
||||||
|
|
||||||
|
# mid
|
||||||
|
mid_block_channel = block_out_channels[-1]
|
||||||
|
|
||||||
|
controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
|
||||||
|
controlnet_block = zero_module(controlnet_block)
|
||||||
|
self.controlnet_mid_block = controlnet_block
|
||||||
|
|
||||||
|
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||||
|
in_channels=mid_block_channel,
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
output_scale_factor=mid_block_scale_factor,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
num_attention_heads=num_attention_heads[-1],
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
|
upcast_attention=upcast_attention,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_unet(
|
||||||
|
cls,
|
||||||
|
unet: UNet2DConditionModel,
|
||||||
|
controlnet_conditioning_channel_order: str = "rgb",
|
||||||
|
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
||||||
|
load_weights_from_unet: bool = True,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
unet (`UNet2DConditionModel`):
|
||||||
|
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
|
||||||
|
where applicable.
|
||||||
|
"""
|
||||||
|
controlnet = cls(
|
||||||
|
in_channels=unet.config.in_channels,
|
||||||
|
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
||||||
|
freq_shift=unet.config.freq_shift,
|
||||||
|
down_block_types=unet.config.down_block_types,
|
||||||
|
only_cross_attention=unet.config.only_cross_attention,
|
||||||
|
block_out_channels=unet.config.block_out_channels,
|
||||||
|
layers_per_block=unet.config.layers_per_block,
|
||||||
|
downsample_padding=unet.config.downsample_padding,
|
||||||
|
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
||||||
|
act_fn=unet.config.act_fn,
|
||||||
|
norm_num_groups=unet.config.norm_num_groups,
|
||||||
|
norm_eps=unet.config.norm_eps,
|
||||||
|
cross_attention_dim=unet.config.cross_attention_dim,
|
||||||
|
attention_head_dim=unet.config.attention_head_dim,
|
||||||
|
num_attention_heads=unet.config.num_attention_heads,
|
||||||
|
use_linear_projection=unet.config.use_linear_projection,
|
||||||
|
class_embed_type=unet.config.class_embed_type,
|
||||||
|
num_class_embeds=unet.config.num_class_embeds,
|
||||||
|
upcast_attention=unet.config.upcast_attention,
|
||||||
|
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
||||||
|
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
||||||
|
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
||||||
|
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
if load_weights_from_unet:
|
||||||
|
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
|
||||||
|
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
||||||
|
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
||||||
|
|
||||||
|
if controlnet.class_embedding:
|
||||||
|
controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
|
||||||
|
|
||||||
|
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
|
||||||
|
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
|
||||||
|
|
||||||
|
return controlnet
|
||||||
|
|
||||||
|
@property
|
||||||
|
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||||
|
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||||
|
r"""
|
||||||
|
Returns:
|
||||||
|
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||||
|
indexed by its weight name.
|
||||||
|
"""
|
||||||
|
# set recursively
|
||||||
|
processors = {}
|
||||||
|
|
||||||
|
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||||
|
if hasattr(module, "set_processor"):
|
||||||
|
processors[f"{name}.processor"] = module.processor
|
||||||
|
|
||||||
|
for sub_name, child in module.named_children():
|
||||||
|
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||||
|
|
||||||
|
return processors
|
||||||
|
|
||||||
|
for name, module in self.named_children():
|
||||||
|
fn_recursive_add_processors(name, module, processors)
|
||||||
|
|
||||||
|
return processors
|
||||||
|
|
||||||
|
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||||
|
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||||
|
r"""
|
||||||
|
Sets the attention processor to use to compute attention.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||||
|
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||||
|
for **all** `Attention` layers.
|
||||||
|
|
||||||
|
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||||
|
processor. This is strongly recommended when setting trainable attention processors.
|
||||||
|
|
||||||
|
"""
|
||||||
|
count = len(self.attn_processors.keys())
|
||||||
|
|
||||||
|
if isinstance(processor, dict) and len(processor) != count:
|
||||||
|
raise ValueError(
|
||||||
|
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||||
|
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||||
|
)
|
||||||
|
|
||||||
|
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||||
|
if hasattr(module, "set_processor"):
|
||||||
|
if not isinstance(processor, dict):
|
||||||
|
module.set_processor(processor)
|
||||||
|
else:
|
||||||
|
module.set_processor(processor.pop(f"{name}.processor"))
|
||||||
|
|
||||||
|
for sub_name, child in module.named_children():
|
||||||
|
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||||
|
|
||||||
|
for name, module in self.named_children():
|
||||||
|
fn_recursive_attn_processor(name, module, processor)
|
||||||
|
|
||||||
|
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||||
|
def set_default_attn_processor(self):
|
||||||
|
"""
|
||||||
|
Disables custom attention processors and sets the default attention implementation.
|
||||||
|
"""
|
||||||
|
self.set_attn_processor(AttnProcessor())
|
||||||
|
|
||||||
|
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||||
|
def set_attention_slice(self, slice_size):
|
||||||
|
r"""
|
||||||
|
Enable sliced attention computation.
|
||||||
|
|
||||||
|
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
||||||
|
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
||||||
|
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
||||||
|
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
||||||
|
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
||||||
|
must be a multiple of `slice_size`.
|
||||||
|
"""
|
||||||
|
sliceable_head_dims = []
|
||||||
|
|
||||||
|
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
||||||
|
if hasattr(module, "set_attention_slice"):
|
||||||
|
sliceable_head_dims.append(module.sliceable_head_dim)
|
||||||
|
|
||||||
|
for child in module.children():
|
||||||
|
fn_recursive_retrieve_sliceable_dims(child)
|
||||||
|
|
||||||
|
# retrieve number of attention layers
|
||||||
|
for module in self.children():
|
||||||
|
fn_recursive_retrieve_sliceable_dims(module)
|
||||||
|
|
||||||
|
num_sliceable_layers = len(sliceable_head_dims)
|
||||||
|
|
||||||
|
if slice_size == "auto":
|
||||||
|
# half the attention head size is usually a good trade-off between
|
||||||
|
# speed and memory
|
||||||
|
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
||||||
|
elif slice_size == "max":
|
||||||
|
# make smallest slice possible
|
||||||
|
slice_size = num_sliceable_layers * [1]
|
||||||
|
|
||||||
|
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
||||||
|
|
||||||
|
if len(slice_size) != len(sliceable_head_dims):
|
||||||
|
raise ValueError(
|
||||||
|
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
||||||
|
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(len(slice_size)):
|
||||||
|
size = slice_size[i]
|
||||||
|
dim = sliceable_head_dims[i]
|
||||||
|
if size is not None and size > dim:
|
||||||
|
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
||||||
|
|
||||||
|
# Recursively walk through all the children.
|
||||||
|
# Any children which exposes the set_attention_slice method
|
||||||
|
# gets the message
|
||||||
|
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
||||||
|
if hasattr(module, "set_attention_slice"):
|
||||||
|
module.set_attention_slice(slice_size.pop())
|
||||||
|
|
||||||
|
for child in module.children():
|
||||||
|
fn_recursive_set_attention_slice(child, slice_size)
|
||||||
|
|
||||||
|
reversed_slice_size = list(reversed(slice_size))
|
||||||
|
for module in self.children():
|
||||||
|
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||||
|
|
||||||
|
def _set_gradient_checkpointing(self, module, value=False):
|
||||||
|
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
||||||
|
module.gradient_checkpointing = value
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
sample: torch.FloatTensor,
|
||||||
|
timestep: Union[torch.Tensor, float, int],
|
||||||
|
encoder_hidden_states: torch.Tensor,
|
||||||
|
controlnet_cond: torch.FloatTensor,
|
||||||
|
conditioning_scale: float = 1.0,
|
||||||
|
class_labels: Optional[torch.Tensor] = None,
|
||||||
|
timestep_cond: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
guess_mode: bool = False,
|
||||||
|
return_dict: bool = True,
|
||||||
|
) -> Union[ControlNetOutput, Tuple]:
|
||||||
|
"""
|
||||||
|
The [`ControlNetModel`] forward method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample (`torch.FloatTensor`):
|
||||||
|
The noisy input tensor.
|
||||||
|
timestep (`Union[torch.Tensor, float, int]`):
|
||||||
|
The number of timesteps to denoise an input.
|
||||||
|
encoder_hidden_states (`torch.Tensor`):
|
||||||
|
The encoder hidden states.
|
||||||
|
controlnet_cond (`torch.FloatTensor`):
|
||||||
|
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||||
|
conditioning_scale (`float`, defaults to `1.0`):
|
||||||
|
The scale factor for ControlNet outputs.
|
||||||
|
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||||
|
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
|
||||||
|
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
||||||
|
encoder_attention_mask (`torch.Tensor`):
|
||||||
|
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
||||||
|
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
||||||
|
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
||||||
|
guess_mode (`bool`, defaults to `False`):
|
||||||
|
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
||||||
|
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||||
|
return_dict (`bool`, defaults to `True`):
|
||||||
|
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
|
||||||
|
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
||||||
|
returned where the first element is the sample tensor.
|
||||||
|
"""
|
||||||
|
# check channel order
|
||||||
|
channel_order = self.config.controlnet_conditioning_channel_order
|
||||||
|
|
||||||
|
if channel_order == "rgb":
|
||||||
|
# in rgb order by default
|
||||||
|
...
|
||||||
|
elif channel_order == "bgr":
|
||||||
|
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
|
||||||
|
|
||||||
|
# prepare attention_mask
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||||
|
attention_mask = attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||||
|
if encoder_attention_mask is not None:
|
||||||
|
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
||||||
|
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
# 1. time
|
||||||
|
timesteps = timestep
|
||||||
|
if not torch.is_tensor(timesteps):
|
||||||
|
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||||
|
# This would be a good case for the `match` statement (Python 3.10+)
|
||||||
|
is_mps = sample.device.type == "mps"
|
||||||
|
if isinstance(timestep, float):
|
||||||
|
dtype = torch.float32 if is_mps else torch.float64
|
||||||
|
else:
|
||||||
|
dtype = torch.int32 if is_mps else torch.int64
|
||||||
|
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||||
|
elif len(timesteps.shape) == 0:
|
||||||
|
timesteps = timesteps[None].to(sample.device)
|
||||||
|
|
||||||
|
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||||
|
timesteps = timesteps.expand(sample.shape[0])
|
||||||
|
|
||||||
|
t_emb = self.time_proj(timesteps)
|
||||||
|
|
||||||
|
# timesteps does not contain any weights and will always return f32 tensors
|
||||||
|
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||||
|
# there might be better ways to encapsulate this.
|
||||||
|
t_emb = t_emb.to(dtype=sample.dtype)
|
||||||
|
|
||||||
|
emb = self.time_embedding(t_emb, timestep_cond)
|
||||||
|
|
||||||
|
if self.class_embedding is not None:
|
||||||
|
if class_labels is None:
|
||||||
|
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||||
|
|
||||||
|
if self.config.class_embed_type == "timestep":
|
||||||
|
class_labels = self.time_proj(class_labels)
|
||||||
|
|
||||||
|
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||||
|
emb = emb + class_emb
|
||||||
|
|
||||||
|
# 2. pre-process
|
||||||
|
sample = self.conv_in(sample)
|
||||||
|
|
||||||
|
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
||||||
|
|
||||||
|
sample = sample + controlnet_cond
|
||||||
|
|
||||||
|
# 3. down
|
||||||
|
down_block_res_samples = (sample,)
|
||||||
|
for downsample_block in self.down_blocks:
|
||||||
|
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||||
|
sample, res_samples = downsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||||
|
|
||||||
|
down_block_res_samples += res_samples
|
||||||
|
|
||||||
|
# 4. mid
|
||||||
|
if self.mid_block is not None:
|
||||||
|
sample = self.mid_block(
|
||||||
|
sample,
|
||||||
|
emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. Control net blocks
|
||||||
|
|
||||||
|
controlnet_down_block_res_samples = ()
|
||||||
|
|
||||||
|
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
||||||
|
down_block_res_sample = controlnet_block(down_block_res_sample)
|
||||||
|
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
||||||
|
|
||||||
|
down_block_res_samples = controlnet_down_block_res_samples
|
||||||
|
|
||||||
|
mid_block_res_sample = self.controlnet_mid_block(sample)
|
||||||
|
|
||||||
|
# 6. scaling
|
||||||
|
if guess_mode and not self.config.global_pool_conditions:
|
||||||
|
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
||||||
|
|
||||||
|
scales = scales * conditioning_scale
|
||||||
|
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
||||||
|
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
||||||
|
else:
|
||||||
|
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
||||||
|
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
||||||
|
|
||||||
|
if self.config.global_pool_conditions:
|
||||||
|
down_block_res_samples = [
|
||||||
|
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
||||||
|
]
|
||||||
|
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (down_block_res_samples, mid_block_res_sample)
|
||||||
|
|
||||||
|
return ControlNetOutput(
|
||||||
|
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
||||||
|
)
|
||||||
|
|
||||||
|
diffusers.ControlNetModel = ControlNetModel
|
||||||
|
diffusers.models.controlnet.ControlNetModel = ControlNetModel
|
@ -58,22 +58,29 @@ sd-1/main/waifu-diffusion:
|
|||||||
recommended: False
|
recommended: False
|
||||||
sd-1/controlnet/canny:
|
sd-1/controlnet/canny:
|
||||||
repo_id: lllyasviel/control_v11p_sd15_canny
|
repo_id: lllyasviel/control_v11p_sd15_canny
|
||||||
|
recommended: True
|
||||||
sd-1/controlnet/inpaint:
|
sd-1/controlnet/inpaint:
|
||||||
repo_id: lllyasviel/control_v11p_sd15_inpaint
|
repo_id: lllyasviel/control_v11p_sd15_inpaint
|
||||||
sd-1/controlnet/mlsd:
|
sd-1/controlnet/mlsd:
|
||||||
repo_id: lllyasviel/control_v11p_sd15_mlsd
|
repo_id: lllyasviel/control_v11p_sd15_mlsd
|
||||||
sd-1/controlnet/depth:
|
sd-1/controlnet/depth:
|
||||||
repo_id: lllyasviel/control_v11f1p_sd15_depth
|
repo_id: lllyasviel/control_v11f1p_sd15_depth
|
||||||
|
recommended: True
|
||||||
sd-1/controlnet/normal_bae:
|
sd-1/controlnet/normal_bae:
|
||||||
repo_id: lllyasviel/control_v11p_sd15_normalbae
|
repo_id: lllyasviel/control_v11p_sd15_normalbae
|
||||||
sd-1/controlnet/seg:
|
sd-1/controlnet/seg:
|
||||||
repo_id: lllyasviel/control_v11p_sd15_seg
|
repo_id: lllyasviel/control_v11p_sd15_seg
|
||||||
sd-1/controlnet/lineart:
|
sd-1/controlnet/lineart:
|
||||||
repo_id: lllyasviel/control_v11p_sd15_lineart
|
repo_id: lllyasviel/control_v11p_sd15_lineart
|
||||||
|
recommended: True
|
||||||
sd-1/controlnet/lineart_anime:
|
sd-1/controlnet/lineart_anime:
|
||||||
repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime
|
repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime
|
||||||
|
sd-1/controlnet/openpose:
|
||||||
|
repo_id: lllyasviel/control_v11p_sd15_openpose
|
||||||
|
recommended: True
|
||||||
sd-1/controlnet/scribble:
|
sd-1/controlnet/scribble:
|
||||||
repo_id: lllyasviel/control_v11p_sd15_scribble
|
repo_id: lllyasviel/control_v11p_sd15_scribble
|
||||||
|
recommended: False
|
||||||
sd-1/controlnet/softedge:
|
sd-1/controlnet/softedge:
|
||||||
repo_id: lllyasviel/control_v11p_sd15_softedge
|
repo_id: lllyasviel/control_v11p_sd15_softedge
|
||||||
sd-1/controlnet/shuffle:
|
sd-1/controlnet/shuffle:
|
||||||
@ -84,6 +91,7 @@ sd-1/controlnet/ip2p:
|
|||||||
repo_id: lllyasviel/control_v11e_sd15_ip2p
|
repo_id: lllyasviel/control_v11e_sd15_ip2p
|
||||||
sd-1/embedding/EasyNegative:
|
sd-1/embedding/EasyNegative:
|
||||||
path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
|
path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
|
||||||
|
recommended: True
|
||||||
sd-1/embedding/ahx-beta-453407d:
|
sd-1/embedding/ahx-beta-453407d:
|
||||||
repo_id: sd-concepts-library/ahx-beta-453407d
|
repo_id: sd-concepts-library/ahx-beta-453407d
|
||||||
sd-1/lora/LowRA:
|
sd-1/lora/LowRA:
|
||||||
|
@ -256,6 +256,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
widgets = dict()
|
widgets = dict()
|
||||||
model_list = [x for x in self.all_models if self.all_models[x].model_type==model_type and not x in exclude]
|
model_list = [x for x in self.all_models if self.all_models[x].model_type==model_type and not x in exclude]
|
||||||
model_labels = [self.model_labels[x] for x in model_list]
|
model_labels = [self.model_labels[x] for x in model_list]
|
||||||
|
|
||||||
|
show_recommended = len(self.installed_models)==0
|
||||||
if len(model_list) > 0:
|
if len(model_list) > 0:
|
||||||
max_width = max([len(x) for x in model_labels])
|
max_width = max([len(x) for x in model_labels])
|
||||||
columns = window_width // (max_width+8) # 8 characters for "[x] " and padding
|
columns = window_width // (max_width+8) # 8 characters for "[x] " and padding
|
||||||
@ -280,7 +282,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
value=[
|
value=[
|
||||||
model_list.index(x)
|
model_list.index(x)
|
||||||
for x in model_list
|
for x in model_list
|
||||||
if self.all_models[x].installed
|
if (show_recommended and self.all_models[x].recommended) \
|
||||||
|
or self.all_models[x].installed
|
||||||
],
|
],
|
||||||
max_height=len(model_list)//columns + 1,
|
max_height=len(model_list)//columns + 1,
|
||||||
relx=4,
|
relx=4,
|
||||||
@ -672,7 +675,9 @@ def select_and_download_models(opt: Namespace):
|
|||||||
# pass
|
# pass
|
||||||
|
|
||||||
installer = ModelInstall(config, prediction_type_helper=helper)
|
installer = ModelInstall(config, prediction_type_helper=helper)
|
||||||
if opt.add or opt.delete:
|
if opt.list_models:
|
||||||
|
installer.list_models(opt.list_models)
|
||||||
|
elif opt.add or opt.delete:
|
||||||
selections = InstallSelections(
|
selections = InstallSelections(
|
||||||
install_models = opt.add or [],
|
install_models = opt.add or [],
|
||||||
remove_models = opt.delete or []
|
remove_models = opt.delete or []
|
||||||
@ -745,7 +750,7 @@ def main():
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--list-models",
|
"--list-models",
|
||||||
choices=["diffusers","loras","controlnets","tis"],
|
choices=[x.value for x in ModelType],
|
||||||
help="list installed models",
|
help="list installed models",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -773,7 +778,7 @@ def main():
|
|||||||
config.parse_args(invoke_args)
|
config.parse_args(invoke_args)
|
||||||
logger = InvokeAILogger().getLogger(config=config)
|
logger = InvokeAILogger().getLogger(config=config)
|
||||||
|
|
||||||
if not (config.conf_path / 'models.yaml').exists():
|
if not config.model_conf_path.exists():
|
||||||
logger.info(
|
logger.info(
|
||||||
"Your InvokeAI root directory is not set up. Calling invokeai-configure."
|
"Your InvokeAI root directory is not set up. Calling invokeai-configure."
|
||||||
)
|
)
|
||||||
|
355
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
355
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
@ -75,11 +75,6 @@ export type paths = {
|
|||||||
* @description Gets a list of models
|
* @description Gets a list of models
|
||||||
*/
|
*/
|
||||||
get: operations["list_models"];
|
get: operations["list_models"];
|
||||||
/**
|
|
||||||
* Import Model
|
|
||||||
* @description Add a model using its local path, repo_id, or remote URL
|
|
||||||
*/
|
|
||||||
post: operations["import_model"];
|
|
||||||
};
|
};
|
||||||
"/api/v1/models/{base_model}/{model_type}/{model_name}": {
|
"/api/v1/models/{base_model}/{model_type}/{model_name}": {
|
||||||
/**
|
/**
|
||||||
@ -93,13 +88,53 @@ export type paths = {
|
|||||||
*/
|
*/
|
||||||
patch: operations["update_model"];
|
patch: operations["update_model"];
|
||||||
};
|
};
|
||||||
|
"/api/v1/models/import": {
|
||||||
|
/**
|
||||||
|
* Import Model
|
||||||
|
* @description Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically
|
||||||
|
*/
|
||||||
|
post: operations["import_model"];
|
||||||
|
};
|
||||||
|
"/api/v1/models/add": {
|
||||||
|
/**
|
||||||
|
* Add Model
|
||||||
|
* @description Add a model using the configuration information appropriate for its type. Only local models can be added by path
|
||||||
|
*/
|
||||||
|
post: operations["add_model"];
|
||||||
|
};
|
||||||
|
"/api/v1/models/rename/{base_model}/{model_type}/{model_name}": {
|
||||||
|
/**
|
||||||
|
* Rename Model
|
||||||
|
* @description Rename a model
|
||||||
|
*/
|
||||||
|
post: operations["rename_model"];
|
||||||
|
};
|
||||||
"/api/v1/models/convert/{base_model}/{model_type}/{model_name}": {
|
"/api/v1/models/convert/{base_model}/{model_type}/{model_name}": {
|
||||||
/**
|
/**
|
||||||
* Convert Model
|
* Convert Model
|
||||||
* @description Convert a checkpoint model into a diffusers model
|
* @description Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none.
|
||||||
*/
|
*/
|
||||||
put: operations["convert_model"];
|
put: operations["convert_model"];
|
||||||
};
|
};
|
||||||
|
"/api/v1/models/search": {
|
||||||
|
/** Search For Models */
|
||||||
|
get: operations["search_for_models"];
|
||||||
|
};
|
||||||
|
"/api/v1/models/ckpt_confs": {
|
||||||
|
/**
|
||||||
|
* List Ckpt Configs
|
||||||
|
* @description Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT.
|
||||||
|
*/
|
||||||
|
get: operations["list_ckpt_configs"];
|
||||||
|
};
|
||||||
|
"/api/v1/models/sync": {
|
||||||
|
/**
|
||||||
|
* Sync To Config
|
||||||
|
* @description Call after making changes to models.yaml, autoimport directories or models directory to synchronize
|
||||||
|
* in-memory data structures with disk data structures.
|
||||||
|
*/
|
||||||
|
get: operations["sync_to_config"];
|
||||||
|
};
|
||||||
"/api/v1/models/merge/{base_model}": {
|
"/api/v1/models/merge/{base_model}": {
|
||||||
/**
|
/**
|
||||||
* Merge Models
|
* Merge Models
|
||||||
@ -397,6 +432,11 @@ export type components = {
|
|||||||
* @default false
|
* @default false
|
||||||
*/
|
*/
|
||||||
force?: boolean;
|
force?: boolean;
|
||||||
|
/**
|
||||||
|
* Merge Dest Directory
|
||||||
|
* @description Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||||
|
*/
|
||||||
|
merge_dest_directory?: string;
|
||||||
};
|
};
|
||||||
/** Body_remove_board_image */
|
/** Body_remove_board_image */
|
||||||
Body_remove_board_image: {
|
Body_remove_board_image: {
|
||||||
@ -1186,7 +1226,7 @@ export type components = {
|
|||||||
* @description The nodes in this graph
|
* @description The nodes in this graph
|
||||||
*/
|
*/
|
||||||
nodes?: {
|
nodes?: {
|
||||||
[key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
|
[key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RealESRGANInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
* Edges
|
* Edges
|
||||||
@ -3302,7 +3342,7 @@ export type components = {
|
|||||||
/** ModelsList */
|
/** ModelsList */
|
||||||
ModelsList: {
|
ModelsList: {
|
||||||
/** Models */
|
/** Models */
|
||||||
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[];
|
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[];
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
* MultiplyInvocation
|
* MultiplyInvocation
|
||||||
@ -3893,6 +3933,41 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
step?: number;
|
step?: number;
|
||||||
};
|
};
|
||||||
|
/**
|
||||||
|
* RealESRGANInvocation
|
||||||
|
* @description Upscales an image using RealESRGAN.
|
||||||
|
*/
|
||||||
|
RealESRGANInvocation: {
|
||||||
|
/**
|
||||||
|
* Id
|
||||||
|
* @description The id of this node. Must be unique among all nodes.
|
||||||
|
*/
|
||||||
|
id: string;
|
||||||
|
/**
|
||||||
|
* Is Intermediate
|
||||||
|
* @description Whether or not this node is an intermediate node.
|
||||||
|
* @default false
|
||||||
|
*/
|
||||||
|
is_intermediate?: boolean;
|
||||||
|
/**
|
||||||
|
* Type
|
||||||
|
* @default realesrgan
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
type?: "realesrgan";
|
||||||
|
/**
|
||||||
|
* Image
|
||||||
|
* @description The input image
|
||||||
|
*/
|
||||||
|
image?: components["schemas"]["ImageField"];
|
||||||
|
/**
|
||||||
|
* Model Name
|
||||||
|
* @description The Real-ESRGAN model to use
|
||||||
|
* @default RealESRGAN_x4plus.pth
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
model_name?: "RealESRGAN_x4plus.pth" | "RealESRGAN_x4plus_anime_6B.pth" | "ESRGAN_SRx4_DF2KOST_official-ff704c30.pth";
|
||||||
|
};
|
||||||
/**
|
/**
|
||||||
* ResizeLatentsInvocation
|
* ResizeLatentsInvocation
|
||||||
* @description Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8.
|
* @description Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8.
|
||||||
@ -4452,47 +4527,6 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
loras: (components["schemas"]["LoraInfo"])[];
|
loras: (components["schemas"]["LoraInfo"])[];
|
||||||
};
|
};
|
||||||
/**
|
|
||||||
* UpscaleInvocation
|
|
||||||
* @description Upscales an image.
|
|
||||||
*/
|
|
||||||
UpscaleInvocation: {
|
|
||||||
/**
|
|
||||||
* Id
|
|
||||||
* @description The id of this node. Must be unique among all nodes.
|
|
||||||
*/
|
|
||||||
id: string;
|
|
||||||
/**
|
|
||||||
* Is Intermediate
|
|
||||||
* @description Whether or not this node is an intermediate node.
|
|
||||||
* @default false
|
|
||||||
*/
|
|
||||||
is_intermediate?: boolean;
|
|
||||||
/**
|
|
||||||
* Type
|
|
||||||
* @default upscale
|
|
||||||
* @enum {string}
|
|
||||||
*/
|
|
||||||
type?: "upscale";
|
|
||||||
/**
|
|
||||||
* Image
|
|
||||||
* @description The input image
|
|
||||||
*/
|
|
||||||
image?: components["schemas"]["ImageField"];
|
|
||||||
/**
|
|
||||||
* Strength
|
|
||||||
* @description The strength
|
|
||||||
* @default 0.75
|
|
||||||
*/
|
|
||||||
strength?: number;
|
|
||||||
/**
|
|
||||||
* Level
|
|
||||||
* @description The upscale level
|
|
||||||
* @default 2
|
|
||||||
* @enum {integer}
|
|
||||||
*/
|
|
||||||
level?: 2 | 4;
|
|
||||||
};
|
|
||||||
/**
|
/**
|
||||||
* VAEModelField
|
* VAEModelField
|
||||||
* @description Vae model field
|
* @description Vae model field
|
||||||
@ -4619,18 +4653,18 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
image?: components["schemas"]["ImageField"];
|
image?: components["schemas"]["ImageField"];
|
||||||
};
|
};
|
||||||
/**
|
|
||||||
* StableDiffusion1ModelFormat
|
|
||||||
* @description An enumeration.
|
|
||||||
* @enum {string}
|
|
||||||
*/
|
|
||||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
|
||||||
/**
|
/**
|
||||||
* StableDiffusion2ModelFormat
|
* StableDiffusion2ModelFormat
|
||||||
* @description An enumeration.
|
* @description An enumeration.
|
||||||
* @enum {string}
|
* @enum {string}
|
||||||
*/
|
*/
|
||||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||||
|
/**
|
||||||
|
* StableDiffusion1ModelFormat
|
||||||
|
* @description An enumeration.
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||||
};
|
};
|
||||||
responses: never;
|
responses: never;
|
||||||
parameters: never;
|
parameters: never;
|
||||||
@ -4741,7 +4775,7 @@ export type operations = {
|
|||||||
};
|
};
|
||||||
requestBody: {
|
requestBody: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
|
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RealESRGANInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
responses: {
|
responses: {
|
||||||
@ -4778,7 +4812,7 @@ export type operations = {
|
|||||||
};
|
};
|
||||||
requestBody: {
|
requestBody: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
|
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RealESRGANInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
responses: {
|
responses: {
|
||||||
@ -4997,37 +5031,6 @@ export type operations = {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
/**
|
|
||||||
* Import Model
|
|
||||||
* @description Add a model using its local path, repo_id, or remote URL
|
|
||||||
*/
|
|
||||||
import_model: {
|
|
||||||
requestBody: {
|
|
||||||
content: {
|
|
||||||
"application/json": components["schemas"]["Body_import_model"];
|
|
||||||
};
|
|
||||||
};
|
|
||||||
responses: {
|
|
||||||
/** @description The model imported successfully */
|
|
||||||
201: {
|
|
||||||
content: {
|
|
||||||
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
|
|
||||||
};
|
|
||||||
};
|
|
||||||
/** @description The model could not be found */
|
|
||||||
404: never;
|
|
||||||
/** @description There is already a model corresponding to this path or repo_id */
|
|
||||||
409: never;
|
|
||||||
/** @description Validation Error */
|
|
||||||
422: {
|
|
||||||
content: {
|
|
||||||
"application/json": components["schemas"]["HTTPValidationError"];
|
|
||||||
};
|
|
||||||
};
|
|
||||||
/** @description The model appeared to import successfully, but could not be found in the model manager */
|
|
||||||
424: never;
|
|
||||||
};
|
|
||||||
};
|
|
||||||
/**
|
/**
|
||||||
* Delete Model
|
* Delete Model
|
||||||
* @description Delete Model
|
* @description Delete Model
|
||||||
@ -5044,12 +5047,6 @@ export type operations = {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
responses: {
|
responses: {
|
||||||
/** @description Successful Response */
|
|
||||||
200: {
|
|
||||||
content: {
|
|
||||||
"application/json": unknown;
|
|
||||||
};
|
|
||||||
};
|
|
||||||
/** @description Model deleted successfully */
|
/** @description Model deleted successfully */
|
||||||
204: never;
|
204: never;
|
||||||
/** @description Model not found */
|
/** @description Model not found */
|
||||||
@ -5079,14 +5076,14 @@ export type operations = {
|
|||||||
};
|
};
|
||||||
requestBody: {
|
requestBody: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
|
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
responses: {
|
responses: {
|
||||||
/** @description The model was updated successfully */
|
/** @description The model was updated successfully */
|
||||||
200: {
|
200: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
|
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
/** @description Bad request */
|
/** @description Bad request */
|
||||||
@ -5101,12 +5098,118 @@ export type operations = {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
/**
|
||||||
|
* Import Model
|
||||||
|
* @description Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically
|
||||||
|
*/
|
||||||
|
import_model: {
|
||||||
|
requestBody: {
|
||||||
|
content: {
|
||||||
|
"application/json": components["schemas"]["Body_import_model"];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
responses: {
|
||||||
|
/** @description The model imported successfully */
|
||||||
|
201: {
|
||||||
|
content: {
|
||||||
|
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
/** @description The model could not be found */
|
||||||
|
404: never;
|
||||||
|
/** @description There is already a model corresponding to this path or repo_id */
|
||||||
|
409: never;
|
||||||
|
/** @description Validation Error */
|
||||||
|
422: {
|
||||||
|
content: {
|
||||||
|
"application/json": components["schemas"]["HTTPValidationError"];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
/** @description The model appeared to import successfully, but could not be found in the model manager */
|
||||||
|
424: never;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
/**
|
||||||
|
* Add Model
|
||||||
|
* @description Add a model using the configuration information appropriate for its type. Only local models can be added by path
|
||||||
|
*/
|
||||||
|
add_model: {
|
||||||
|
requestBody: {
|
||||||
|
content: {
|
||||||
|
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
responses: {
|
||||||
|
/** @description The model added successfully */
|
||||||
|
201: {
|
||||||
|
content: {
|
||||||
|
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
/** @description The model could not be found */
|
||||||
|
404: never;
|
||||||
|
/** @description There is already a model corresponding to this path or repo_id */
|
||||||
|
409: never;
|
||||||
|
/** @description Validation Error */
|
||||||
|
422: {
|
||||||
|
content: {
|
||||||
|
"application/json": components["schemas"]["HTTPValidationError"];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
/** @description The model appeared to add successfully, but could not be found in the model manager */
|
||||||
|
424: never;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
/**
|
||||||
|
* Rename Model
|
||||||
|
* @description Rename a model
|
||||||
|
*/
|
||||||
|
rename_model: {
|
||||||
|
parameters: {
|
||||||
|
query?: {
|
||||||
|
/** @description new model name */
|
||||||
|
new_name?: string;
|
||||||
|
/** @description new model base */
|
||||||
|
new_base?: components["schemas"]["BaseModelType"];
|
||||||
|
};
|
||||||
|
path: {
|
||||||
|
/** @description Base model */
|
||||||
|
base_model: components["schemas"]["BaseModelType"];
|
||||||
|
/** @description The type of model */
|
||||||
|
model_type: components["schemas"]["ModelType"];
|
||||||
|
/** @description current model name */
|
||||||
|
model_name: string;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
responses: {
|
||||||
|
/** @description The model was renamed successfully */
|
||||||
|
201: {
|
||||||
|
content: {
|
||||||
|
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
/** @description The model could not be found */
|
||||||
|
404: never;
|
||||||
|
/** @description There is already a model corresponding to the new name */
|
||||||
|
409: never;
|
||||||
|
/** @description Validation Error */
|
||||||
|
422: {
|
||||||
|
content: {
|
||||||
|
"application/json": components["schemas"]["HTTPValidationError"];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
/**
|
/**
|
||||||
* Convert Model
|
* Convert Model
|
||||||
* @description Convert a checkpoint model into a diffusers model
|
* @description Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none.
|
||||||
*/
|
*/
|
||||||
convert_model: {
|
convert_model: {
|
||||||
parameters: {
|
parameters: {
|
||||||
|
query?: {
|
||||||
|
/** @description Save the converted model to the designated directory */
|
||||||
|
convert_dest_directory?: string;
|
||||||
|
};
|
||||||
path: {
|
path: {
|
||||||
/** @description Base model */
|
/** @description Base model */
|
||||||
base_model: components["schemas"]["BaseModelType"];
|
base_model: components["schemas"]["BaseModelType"];
|
||||||
@ -5120,7 +5223,7 @@ export type operations = {
|
|||||||
/** @description Model converted successfully */
|
/** @description Model converted successfully */
|
||||||
200: {
|
200: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
|
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
/** @description Bad request */
|
/** @description Bad request */
|
||||||
@ -5135,6 +5238,60 @@ export type operations = {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
/** Search For Models */
|
||||||
|
search_for_models: {
|
||||||
|
parameters: {
|
||||||
|
query: {
|
||||||
|
/** @description Directory path to search for models */
|
||||||
|
search_path: string;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
responses: {
|
||||||
|
/** @description Directory searched successfully */
|
||||||
|
200: {
|
||||||
|
content: {
|
||||||
|
"application/json": (string)[];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
/** @description Invalid directory path */
|
||||||
|
404: never;
|
||||||
|
/** @description Validation Error */
|
||||||
|
422: {
|
||||||
|
content: {
|
||||||
|
"application/json": components["schemas"]["HTTPValidationError"];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
|
/**
|
||||||
|
* List Ckpt Configs
|
||||||
|
* @description Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT.
|
||||||
|
*/
|
||||||
|
list_ckpt_configs: {
|
||||||
|
responses: {
|
||||||
|
/** @description paths retrieved successfully */
|
||||||
|
200: {
|
||||||
|
content: {
|
||||||
|
"application/json": (string)[];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
|
/**
|
||||||
|
* Sync To Config
|
||||||
|
* @description Call after making changes to models.yaml, autoimport directories or models directory to synchronize
|
||||||
|
* in-memory data structures with disk data structures.
|
||||||
|
*/
|
||||||
|
sync_to_config: {
|
||||||
|
responses: {
|
||||||
|
/** @description synchronization successful */
|
||||||
|
201: {
|
||||||
|
content: {
|
||||||
|
"application/json": unknown;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
/**
|
/**
|
||||||
* Merge Models
|
* Merge Models
|
||||||
* @description Convert a checkpoint model into a diffusers model
|
* @description Convert a checkpoint model into a diffusers model
|
||||||
@ -5155,7 +5312,7 @@ export type operations = {
|
|||||||
/** @description Model converted successfully */
|
/** @description Model converted successfully */
|
||||||
200: {
|
200: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
|
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
/** @description Incompatible models */
|
/** @description Incompatible models */
|
||||||
|
@ -1 +1 @@
|
|||||||
__version__ = "3.0.0+b5"
|
__version__ = "3.0.0+b6"
|
||||||
|
@ -55,7 +55,6 @@ def mock_services() -> InvocationServices:
|
|||||||
),
|
),
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||||
processor = DefaultInvocationProcessor(),
|
processor = DefaultInvocationProcessor(),
|
||||||
restoration = None, # type: ignore
|
|
||||||
configuration = None, # type: ignore
|
configuration = None, # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -48,7 +48,6 @@ def mock_services() -> InvocationServices:
|
|||||||
),
|
),
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||||
processor = DefaultInvocationProcessor(),
|
processor = DefaultInvocationProcessor(),
|
||||||
restoration = None, # type: ignore
|
|
||||||
configuration = None, # type: ignore
|
configuration = None, # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from .test_nodes import ImageToImageTestInvocation, TextToImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation
|
from .test_nodes import ImageToImageTestInvocation, TextToImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation
|
||||||
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
|
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
|
||||||
from invokeai.app.invocations.upscale import UpscaleInvocation
|
from invokeai.app.invocations.upscale import RealESRGANInvocation
|
||||||
from invokeai.app.invocations.image import *
|
from invokeai.app.invocations.image import *
|
||||||
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
||||||
from invokeai.app.invocations.params import ParamIntInvocation
|
from invokeai.app.invocations.params import ParamIntInvocation
|
||||||
@ -19,7 +19,7 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg
|
|||||||
def test_connections_are_compatible():
|
def test_connections_are_compatible():
|
||||||
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
from_field = "image"
|
from_field = "image"
|
||||||
to_node = UpscaleInvocation(id = "2")
|
to_node = RealESRGANInvocation(id = "2")
|
||||||
to_field = "image"
|
to_field = "image"
|
||||||
|
|
||||||
result = are_connections_compatible(from_node, from_field, to_node, to_field)
|
result = are_connections_compatible(from_node, from_field, to_node, to_field)
|
||||||
@ -29,7 +29,7 @@ def test_connections_are_compatible():
|
|||||||
def test_connections_are_incompatible():
|
def test_connections_are_incompatible():
|
||||||
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
from_field = "image"
|
from_field = "image"
|
||||||
to_node = UpscaleInvocation(id = "2")
|
to_node = RealESRGANInvocation(id = "2")
|
||||||
to_field = "strength"
|
to_field = "strength"
|
||||||
|
|
||||||
result = are_connections_compatible(from_node, from_field, to_node, to_field)
|
result = are_connections_compatible(from_node, from_field, to_node, to_field)
|
||||||
@ -39,7 +39,7 @@ def test_connections_are_incompatible():
|
|||||||
def test_connections_incompatible_with_invalid_fields():
|
def test_connections_incompatible_with_invalid_fields():
|
||||||
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
from_field = "invalid_field"
|
from_field = "invalid_field"
|
||||||
to_node = UpscaleInvocation(id = "2")
|
to_node = RealESRGANInvocation(id = "2")
|
||||||
to_field = "image"
|
to_field = "image"
|
||||||
|
|
||||||
# From field is invalid
|
# From field is invalid
|
||||||
@ -86,10 +86,10 @@ def test_graph_fails_to_update_node_if_type_changes():
|
|||||||
g = Graph()
|
g = Graph()
|
||||||
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
g.add_node(n)
|
g.add_node(n)
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
|
|
||||||
nu = UpscaleInvocation(id = "1")
|
nu = RealESRGANInvocation(id = "1")
|
||||||
|
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
g.update_node("1", nu)
|
g.update_node("1", nu)
|
||||||
@ -98,7 +98,7 @@ def test_graph_allows_non_conflicting_id_change():
|
|||||||
g = Graph()
|
g = Graph()
|
||||||
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
g.add_node(n)
|
g.add_node(n)
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
e1 = create_edge(n.id,"image",n2.id,"image")
|
e1 = create_edge(n.id,"image",n2.id,"image")
|
||||||
g.add_edge(e1)
|
g.add_edge(e1)
|
||||||
@ -128,7 +128,7 @@ def test_graph_fails_to_update_node_id_if_conflict():
|
|||||||
def test_graph_adds_edge():
|
def test_graph_adds_edge():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
e = create_edge(n1.id,"image",n2.id,"image")
|
e = create_edge(n1.id,"image",n2.id,"image")
|
||||||
@ -139,7 +139,7 @@ def test_graph_adds_edge():
|
|||||||
|
|
||||||
def test_graph_fails_to_add_edge_with_cycle():
|
def test_graph_fails_to_add_edge_with_cycle():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = UpscaleInvocation(id = "1")
|
n1 = RealESRGANInvocation(id = "1")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
e = create_edge(n1.id,"image",n1.id,"image")
|
e = create_edge(n1.id,"image",n1.id,"image")
|
||||||
with pytest.raises(InvalidEdgeError):
|
with pytest.raises(InvalidEdgeError):
|
||||||
@ -148,8 +148,8 @@ def test_graph_fails_to_add_edge_with_cycle():
|
|||||||
def test_graph_fails_to_add_edge_with_long_cycle():
|
def test_graph_fails_to_add_edge_with_long_cycle():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
n3 = UpscaleInvocation(id = "3")
|
n3 = RealESRGANInvocation(id = "3")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
g.add_node(n3)
|
g.add_node(n3)
|
||||||
@ -164,7 +164,7 @@ def test_graph_fails_to_add_edge_with_long_cycle():
|
|||||||
def test_graph_fails_to_add_edge_with_missing_node_id():
|
def test_graph_fails_to_add_edge_with_missing_node_id():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
e1 = create_edge("1","image","3","image")
|
e1 = create_edge("1","image","3","image")
|
||||||
@ -177,8 +177,8 @@ def test_graph_fails_to_add_edge_with_missing_node_id():
|
|||||||
def test_graph_fails_to_add_edge_when_destination_exists():
|
def test_graph_fails_to_add_edge_when_destination_exists():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
n3 = UpscaleInvocation(id = "3")
|
n3 = RealESRGANInvocation(id = "3")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
g.add_node(n3)
|
g.add_node(n3)
|
||||||
@ -194,7 +194,7 @@ def test_graph_fails_to_add_edge_when_destination_exists():
|
|||||||
def test_graph_fails_to_add_edge_with_mismatched_types():
|
def test_graph_fails_to_add_edge_with_mismatched_types():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
e1 = create_edge("1","image","2","strength")
|
e1 = create_edge("1","image","2","strength")
|
||||||
@ -344,7 +344,7 @@ def test_graph_iterator_invalid_if_output_and_input_types_different():
|
|||||||
def test_graph_validates():
|
def test_graph_validates():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
e1 = create_edge("1","image","2","image")
|
e1 = create_edge("1","image","2","image")
|
||||||
@ -377,8 +377,8 @@ def test_graph_invalid_if_subgraph_invalid():
|
|||||||
|
|
||||||
def test_graph_invalid_if_has_cycle():
|
def test_graph_invalid_if_has_cycle():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = UpscaleInvocation(id = "1")
|
n1 = RealESRGANInvocation(id = "1")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
g.nodes[n1.id] = n1
|
g.nodes[n1.id] = n1
|
||||||
g.nodes[n2.id] = n2
|
g.nodes[n2.id] = n2
|
||||||
e1 = create_edge("1","image","2","image")
|
e1 = create_edge("1","image","2","image")
|
||||||
@ -391,7 +391,7 @@ def test_graph_invalid_if_has_cycle():
|
|||||||
def test_graph_invalid_with_invalid_connection():
|
def test_graph_invalid_with_invalid_connection():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
g.nodes[n1.id] = n1
|
g.nodes[n1.id] = n1
|
||||||
g.nodes[n2.id] = n2
|
g.nodes[n2.id] = n2
|
||||||
e1 = create_edge("1","image","2","strength")
|
e1 = create_edge("1","image","2","strength")
|
||||||
@ -503,7 +503,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
|
|||||||
|
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
|
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
|
|
||||||
with pytest.raises(NodeNotFoundError):
|
with pytest.raises(NodeNotFoundError):
|
||||||
@ -512,7 +512,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
|
|||||||
def test_graph_gets_networkx_graph():
|
def test_graph_gets_networkx_graph():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
e = create_edge(n1.id,"image",n2.id,"image")
|
e = create_edge(n1.id,"image",n2.id,"image")
|
||||||
@ -529,7 +529,7 @@ def test_graph_gets_networkx_graph():
|
|||||||
def test_graph_can_serialize():
|
def test_graph_can_serialize():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
e = create_edge(n1.id,"image",n2.id,"image")
|
e = create_edge(n1.id,"image",n2.id,"image")
|
||||||
@ -541,7 +541,7 @@ def test_graph_can_serialize():
|
|||||||
def test_graph_can_deserialize():
|
def test_graph_can_deserialize():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
e = create_edge(n1.id,"image",n2.id,"image")
|
e = create_edge(n1.id,"image",n2.id,"image")
|
||||||
|
Loading…
Reference in New Issue
Block a user