Compare commits

...

231 Commits

Author SHA1 Message Date
efabf250d7 Merge branch 'main' into Convert-Model-Endpoint 2023-05-18 18:51:38 -04:00
7025c00581 Add configuration system, remove legacy globals, args, generate and CLI (#3340)
# Application-wide configuration service

This PR creates a new `InvokeAIAppConfig` object that reads
application-wide settings from an init file, the environment, and the
command line.

Arguments and fields are taken from the pydantic definition of the
model. Defaults can be set by creating a yaml configuration file that
has a top-level key of "InvokeAI" and subheadings for each of the
categories returned by `invokeai --help`.

The file looks like this:

[file: invokeai.yaml]
```
InvokeAI:
  Paths:
    root: /home/lstein/invokeai-main
    conf_path: configs/models.yaml
    legacy_conf_dir: configs/stable-diffusion
    outdir: outputs
    embedding_dir: embeddings
    lora_dir: loras
    autoconvert_dir: null
    gfpgan_model_dir: models/gfpgan/GFPGANv1.4.pth
  Models:
    model: stable-diffusion-1.5
    embeddings: true
  Memory/Performance:
    xformers_enabled: false
    sequential_guidance: false
    precision: float16
    max_loaded_models: 4
    always_use_cpu: false
    free_gpu_mem: false
  Features:
    nsfw_checker: true
    restore: true
    esrgan: true
    patchmatch: true
    internet_available: true
    log_tokenization: false
  Cross-Origin Resource Sharing:
    allow_origins: []
    allow_credentials: true
    allow_methods:
    - '*'
    allow_headers:
    - '*'
  Web Server:
    host: 127.0.0.1
    port: 8081

```

The default name of the configuration file is `invokeai.yaml`, located
in INVOKEAI_ROOT. You can use any OmegaConf dictionary by passing it to
the config object at initialization time:

```
 omegaconf = OmegaConf.load('/tmp/init.yaml')
 conf = InvokeAIAppConfig(conf=omegaconf)
```

By default, InvokeAIAppConfig will parse the contents of `sys.argv` at
initialization time. You may pass a list of strings in the optional
`argv` argument to use instead of the system argv:

```
conf = InvokeAIAppConfig(argv=['--xformers_enabled'])
```

It is also possible to set a value at initialization time. This value
has highest priority.
```
conf = InvokeAIAppConfig(xformers_enabled=True)
```
Any setting can be overridden by setting an environment variable of the
form `INVOKEAI_<setting>`, as in:

```
export INVOKEAI_port=8080
```

Order of precedence (from highest):
   1) initialization options
   2) command line options
   3) environment variable options
   4) config file options
   5) pydantic defaults
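
A minimal sketch of how these layers interact, assuming the example
`invokeai.yaml` above (which sets `port: 8081`) is in place; the values in
the comments follow from the stated precedence rules rather than from a
test run:

```
import os
from invokeai.app.services.config import InvokeAIAppConfig

os.environ['INVOKEAI_port'] = '8080'          # environment variable layer
conf = InvokeAIAppConfig(argv=[])             # no command-line overrides
print(conf.port)                              # 8080: env var beats the config file's 8081

conf = InvokeAIAppConfig(argv=[], port=9090)  # initialization option
print(conf.port)                              # 9090: initialization options rank highest
```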

Typical usage:

```
from invokeai.app.services.config import InvokeAIAppConfig

# get global configuration and print its nsfw_checker value
conf = InvokeAIAppConfig()
print(conf.nsfw_checker)
```
Finally, the configuration object can recreate its (modified) yaml file
by calling its `to_yaml()` method:

```
conf = InvokeAIAppConfig(outdir='/tmp', port=8080)
print(conf.to_yaml())
```

# Legacy code removal and porting

This PR replaces Globals with the InvokeAIAppConfig system throughout,
and therefore removes the `globals.py` and `args.py` modules. It also
removes `generate` and the legacy CLI. ***The old CLI and web servers
are now gone.***

I have ported the functionality of the configuration script, the model
installer, and the merge and textual inversion scripts. The `invokeai`
command will now launch `invokeai-node-cli`, and `invokeai-web` will
launch the web server.

I have changed the continuous invocation tests to accommodate the new
command syntax in `invokeai-node-cli`. As a convenience, you can also
pass invocations to `invokeai-node-cli` (or its alias `invokeai`) on the
command line or via standard input:

```
invokeai-node-cli "t2i --positive_prompt 'banana sushi' --seed 42"
invokeai < invocation_commands.txt
```
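
For illustration only (the file name and second prompt are hypothetical,
and the command syntax assumes the `t2i` form shown above),
`invocation_commands.txt` could contain one invocation per line:

```
t2i --positive_prompt 'banana sushi' --seed 42
t2i --positive_prompt 'a red bicycle' --seed 43
```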
2023-05-18 13:37:09 -04:00
7ea995149e fixes to env parsing, textual inversion & help text
- Make environment variable settings case InSenSiTive:
  INVOKEAI_MAX_LOADED_MODELS and InvokeAI_Max_Loaded_Models
  environment variables will both set `max_loaded_models`

- Updated realesrgan to use new config system.

- Updated textual_inversion_training to use new config system.

- Discovered a race condition when InvokeAIAppConfig is created
  at module load time, which makes it impossible to customize
  or replace the help message produced with --help on the command
  line. To fix this, moved all instances of get_invokeai_config()
  from module load time to object initialization time. Makes code
  cleaner, too.

- Added `--from_file` argument to `invokeai-node-cli` and changed
  github action to match. CI tests will hopefully work now.
2023-05-18 10:48:23 -04:00
f9710dd6ed remove reference to legacy opt.hf_token, clean up whitespace in invokeai_configure 2023-05-17 20:39:00 -04:00
4e7dd7d3f6 ci: remove reference to Globals in a workflow 2023-05-17 20:26:26 -04:00
20ca9e1fc1 config: move 'CORS' settings to 'Web Server' in the docstring to match the actual category 2023-05-17 19:45:51 -04:00
8a8b09a953 api_app: rename web_config to app_config for consistency 2023-05-17 19:42:13 -04:00
9e4e386c9b web and formatting fixes
- remove non-existent import InvokeAIWebConfig
- fix workflow file formatting
- clean up whitespace
2023-05-17 19:12:03 -04:00
eca1e449a8 Merge branch 'lstein/global-configuration' of github.com:invoke-ai/InvokeAI into lstein/global-configuration 2023-05-17 15:23:21 -04:00
ffaadb9d05 reorder options in help text 2023-05-17 15:22:58 -04:00
8adff96e29 Merge branch 'main' into lstein/global-configuration 2023-05-17 14:37:09 -04:00
7593dc19d6 complete several steps needed to make 3.0 installable
- invokeai-configure updated to work with new config system
- migrate invokeai.init to invokeai.yaml during configure
- replace legacy invokeai with invokeai-node-cli
- add ability to run an invocation directly from invokeai-node-cli command line
- update CI tests to work with new invokeai syntax
2023-05-17 14:13:27 -04:00
b7c5a39685 make invokeai.yaml more hierarchical; fix list configuration bug 2023-05-17 12:19:19 -04:00
bd1b84f7d0 tell user to refresh page on image load error (#3425)
* refetch images list if error loading

* tell user to refresh instead of refetching

* unused import

* feat(ui): use `useAppToaster` to make toast

* fix(ui): clear selected/initial image on error

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2023-05-17 11:52:37 -04:00
eadfd239a8 update config script to work with new config system 2023-05-17 00:18:19 -04:00
8d75e50435 partial port of invokeai-configure 2023-05-16 01:50:01 -04:00
1d9c115225 feat(nodes): add low and high to RandomIntInvocation 2023-05-16 13:50:52 +10:00
30af20a056 ui: cleanup (#3418)
- tidy up a lot of cruft
- `sampler` --> `scheduler`
2023-05-16 15:27:12 +12:00
cc21fb216c chore(ui): clean up GalleryPanel 2023-05-16 10:43:26 +10:00
6fe62a2705 feat(ui): sampler --> scheduler 2023-05-16 10:40:26 +10:00
da87378713 chore(ui): regen api client 2023-05-16 10:39:40 +10:00
b6f5267385 chore(ui): clean up generationSlice 2023-05-16 10:21:18 +10:00
f9e78d3c64 chore(ui): clean up gallerySlice 2023-05-16 10:16:36 +10:00
b7b5bd1b46 chore(ui): clean up uiSlice 2023-05-16 09:57:19 +10:00
9a3727d3ad chore(ui): clean up systemSlice 2023-05-16 09:48:58 +10:00
d68c14516c chore(ui): clean up persist denylists 2023-05-16 09:46:03 +10:00
9f4d39aa42 chore(ui): clean up modelSlice 2023-05-16 09:45:49 +10:00
84b801d88f ui: restore canvas and upload functionality (#3414)
- refactor image uploading, fix init image upload button 
- refactor toast and hotkey hooks into logical components
- restore canvas save/download/copy/merge functionality
- clean up unused files and packages
- fix canvas rendering issue resulting from fractional stage coords
2023-05-16 02:23:39 +12:00
2fc70c509b Merge branch 'main' into feat/ui/fix-uploading 2023-05-16 02:20:59 +12:00
34fb1c4b19 make conditioning.py work with compel 1.1.5 (#3383)
This PR fixes the ValueError issue that was preventing all prompts from
working.
2023-05-15 09:46:04 -04:00
80bdd550cf Merge branch 'main' into lstein/bugfix/compel 2023-05-15 09:25:21 -04:00
7ef0d2aa35 merge with main 2023-05-15 09:07:17 -04:00
2359b92b46 chore(ui): tidy unused component ref 2023-05-15 22:58:15 +10:00
a404fb2d32 docs(ui): update PACKAGE_SCRIPTS.md 2023-05-15 22:49:28 +10:00
513eb11616 chore(ui): clean up unused files/packages 2023-05-15 22:48:06 +10:00
d2c9140e69 feat(ui): restore save/copy/download/merge functionality 2023-05-15 22:21:03 +10:00
d95fe5925a feat(ui): restore image post-upload actions
eg set init image if on img2img when uploading
2023-05-15 18:52:48 +10:00
835922ea8f fix(ui): floor canvas coords to prevent partial pixel offset rendering issues 2023-05-15 18:50:34 +10:00
e1e5266fc3 feat(ui): refactor base image uploading logic 2023-05-15 17:45:05 +10:00
5e4457445f feat(ui): make toast/hotkey into logical components 2023-05-15 15:25:27 +10:00
0221ca8f49 fix(ui): use cloned canvas for retrieving dataURL/Blobs 2023-05-15 13:54:30 +10:00
cf36e4029e fix(ui): fix syntax error in the logo component flexbox 2023-05-15 08:24:33 +10:00
c8a98a9a22 Merge branch 'main' into lstein/bugfix/compel 2023-05-14 14:43:18 -04:00
38ecca9362 Logging Improvements (#3401)
This PR improves the logging module a bit, along with the
documentation.

**New Look:**


![WindowsTerminal_XaijwCqFpo](https://github.com/invoke-ai/InvokeAI/assets/54517381/49a97411-1927-4a49-80ff-f4d9665be55f)

## Usage

**General Logger**

InvokeAI has a module-level logger. You can call it this way.

In the example below, you use the default logger `InvokeAI`, and all your
messages will be logged under that name.

```python

from invokeai.backend.util.logging import logger

logger.critical("CriticalMessage") // In Bold Red
logger.error("Info Message") // In Red
logger.warning("Info Message") // In Yellow
logger.info("Info Message") // In Grey 
logger.debug("Debug Message") // In Grey
```

Results:

```
[12-05-2023 20]::[InvokeAI]::CRITICAL --> This is a critical message [In Bold Red]
[12-05-2023 20]::[InvokeAI]::ERROR --> This is an error message [In Red]
[12-05-2023 20]::[InvokeAI]::WARNING --> This is a warning message [In Yellow]
[12-05-2023 20]::[InvokeAI]::INFO --> This is an info message [In Grey]
[12-05-2023 20]::[InvokeAI]::DEBUG --> This is a debug message [In Grey]
```

**Custom Logger**

If you want to use a custom logger for your module, you can create one as
follows.

```python

from invokeai.backend.util.logging import logging
logger = logging.getLogger(name='Model Manager')

logger.critical("CriticalMessage") // In Bold Red
logger.error("Info Message") // In Red
logger.warning("Info Message") // In Yellow
logger.info("Info Message") // In Grey 
logger.debug("Debug Message") // In Grey
```

Results:

```
[12-05-2023 20]::[Model Manager]::CRITICAL --> This is a critical message [In Bold Red]
[12-05-2023 20]::[Model Manager]::ERROR --> This is an error message [In Red]
[12-05-2023 20]::[Model Manager]::WARNING --> This is a warning message [In Yellow]
[12-05-2023 20]::[Model Manager]::INFO --> This is an info message [In Grey]
[12-05-2023 20]::[Model Manager]::DEBUG --> This is a debug message [In Grey]
```

**When to use a custom logger?**

It is recommended to use a custom logger if your module is not part of the
InvokeAI core, for example custom extensions or nodes.
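
A minimal sketch of that recommendation in practice; the logger name
"My Custom Nodes" is hypothetical and simply reuses the
`getLogger(name=...)` pattern shown above:

```python
from invokeai.backend.util.logging import logging

# Give the extension its own logger name so its output is distinguishable
# from the default "InvokeAI" logger.
logger = logging.getLogger(name='My Custom Nodes')

logger.info("custom node pack loaded")
```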
2023-05-15 02:18:20 +12:00
c4681774a5 Merge branch 'main' into logging-facelift 2023-05-15 02:08:29 +12:00
050add58d2 fix getting conditionings 2023-05-14 12:20:54 +02:00
3d60c958c7 ui: commercial fixes (#3409)
minor commercial fixes
2023-05-14 20:44:06 +12:00
f5df150097 feat(ui): add callback to signal app is ready
needed for commercial
2023-05-14 18:42:15 +10:00
dac82adb5b fix(ui): make logo component non-selectable 2023-05-14 18:41:11 +10:00
b72c9787a9 Revert "comment out customer_attention_context"
This reverts commit 8f8cd90787.

Due to NameError: name 'options' is not defined
2023-05-14 00:37:55 -04:00
2623941d91 Merge branch 'main' into lstein/bugfix/compel 2023-05-13 22:23:59 -04:00
d3a7fea939 Revert "fix: Rework the layout of the parameters scrollbar"
This reverts commit 6f1fc397f7.
2023-05-14 11:45:08 +10:00
5a7b687c84 fix(ui): add missing packages 2023-05-14 11:45:08 +10:00
0020457fc7 fix(ui): tweak settings scheduler styling 2023-05-14 11:45:08 +10:00
658b556544 feat(ui): IAICustomSelect v2, implement for scheduler & model 2023-05-14 11:45:08 +10:00
37da0fc075 feat(ui): IAICustomSelect v1 2023-05-14 11:45:08 +10:00
6d3e8507cc fix(ui): fix "no image" fallbacks 2023-05-14 11:45:08 +10:00
0e9470503f fix: Rework the layout of the parameters scrollbar 2023-05-14 11:45:08 +10:00
d2ebc6741b feat: Add setting to hide / display schedulers 2023-05-14 11:45:08 +10:00
026d3260b4 Add Heun Karras Scheduler 2023-05-14 11:45:08 +10:00
1103ab2844 merge with main 2023-05-13 21:35:19 -04:00
11b2076b46 implement change to web_config suggested by ebr 2023-05-13 21:33:19 -04:00
78533714e3 Merge branch 'main' into logging-facelift 2023-05-14 09:07:51 +12:00
691e1bf829 Make debug messages cyan/blue 2023-05-14 09:06:57 +12:00
47a088d685 rehydrate selectedImage URL when results and uploads are fetched 2023-05-13 09:48:38 +10:00
63db3fc22f reduce queue check interval to 0.5s 2023-05-12 17:54:26 -04:00
ad0bb3f61a fix: queue error should not crash InvocationProcessor
1. if retrieving an item from the queue raises an exception, the
   InvocationProcessor thread crashes, but the API continues running in
   a non-functional state. This fixes the issue
2. when there are no items in the queue, sleep 1 second before checking
   again.
3. Also ensures the thread isn't crashed if an exception is raised from
   invoker, and emits the error event

Intentionally using base Exceptions because for now we don't know which
specific exception to expect.

Fixes (sort of)? #3222
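
A rough sketch of the defensive loop this describes (names like `invoke`
and `emit_error` are stand-ins, not the actual InvocationProcessor API):

```python
import queue
import time
import traceback

def invoke(item):            # stand-in for handing the item to the invoker
    ...

def emit_error(item, err):   # stand-in for emitting the session error event
    print(f"error while invoking {item}: {err}")

def process_loop(q: queue.Queue):
    while True:
        try:
            item = q.get(timeout=1)   # may raise; must not kill the thread
        except queue.Empty:
            time.sleep(1)             # queue empty: wait a second, then poll again
            continue
        except Exception:
            traceback.print_exc()     # unknown error: log it and keep running
            continue
        try:
            invoke(item)
        except Exception as err:
            emit_error(item, err)     # report the failure instead of crashing
```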
2023-05-12 17:54:26 -04:00
8f8cd90787 comment out customer_attention_context 2023-05-12 13:59:00 -04:00
d796ea7bec feat: Logging Improvements 2023-05-13 02:13:49 +12:00
e5b7dd63e9 fix(nodes): temporarily disable librarygraphs
- Do not retrieve graph from DB until we resolve the issue of changing node schemas causing application to fail to start up due to invalid graphs
2023-05-12 22:33:49 +10:00
af060188bd Merge branch 'main' into lstein/bugfix/compel 2023-05-12 08:22:18 -04:00
4270e7ae25 Feat/ui/improve-language (#3399) 2023-05-12 23:32:50 +12:00
60a565d7de feat(ui): use chakra menu for theme changer 2023-05-12 20:04:29 +10:00
78cf70eaad fix(ui): tweak lang picker style 2023-05-12 20:04:10 +10:00
eebaa50710 fix(ui): fix language picker tooltip 2023-05-12 19:52:21 +10:00
7d582553f2 feat(ui): use chakra menu for language picker 2023-05-12 19:50:34 +10:00
4d6eea7e81 feat(ui): store language in redux 2023-05-12 19:35:03 +10:00
f44593331d ui: misc fixes (#3398)
- do not show canvas intermediates in gallery
- do not show progress image in uploads gallery category
- use custom dark mode `localStorage` key (prevents collision with
commercial)
- use variable font (reduce bundle size by factor of 10)
- change how custom headers are used
- use style injection for building package
- fix tab icon sizes
2023-05-12 21:00:47 +12:00
3d9ecbf3c7 fix(ui): add missing package 2023-05-12 18:55:59 +10:00
032aa1d59c fix(ui): excise most zIndexs
our stacking contexts are accurate, `zIndex` isn't needed
2023-05-12 18:50:54 +10:00
35e0863bdb fix(ui): fix tab icon sizes 2023-05-12 17:56:18 +10:00
14070d674e build(ui): add style injection plugin
when building for package, CSS is all in JS files. when used as a package, it is then injected into the page. a bit of a hack to work around missing CSS in the commercial product
2023-05-12 17:56:18 +10:00
108ce06c62 feat(ui): change custom header to be a prop instead of children 2023-05-12 17:56:18 +10:00
da364f3444 feat(ui): use variable font
reduces package build's CSS by an order of magnitude
2023-05-12 17:56:18 +10:00
df5ba75c14 feat(ui): use custom dark mode localStorage key 2023-05-12 17:56:18 +10:00
e4fb9cb33f chore(ui): regen api client 2023-05-12 17:56:18 +10:00
65b527eb20 fix(ui): do not show progress images in uploads gallery category 2023-05-12 17:56:18 +10:00
7dc9d18052 fix(ui): do not show intermediates uploads in gallery 2023-05-12 17:56:18 +10:00
5013a4b9f3 feat(ui): expand config options (#3393)
individual SD features (eg Noise, Variation, etc) may now be disabled -
stuff which is not ready for consumption in the commercial product.
2023-05-12 16:10:17 +12:00
f929359322 Merge branch 'main' into feat/ui/expand-config 2023-05-12 16:06:31 +12:00
6522c71971 feat(nodes): add RandomIntInvocation (#3390)
just outputs a single random int
2023-05-12 16:06:06 +12:00
9c1e65f3a3 Merge branch 'main' into feat/nodes/add-randomintinvocation 2023-05-12 15:56:41 +12:00
ebec200ba6 Remove unused import 2023-05-12 13:56:02 +10:00
e559730b6e feat(nodes): add w/h to latents outputs (#3389)
This reduces the number of nodes needed when working with latents (ie
fewer plain integer value nodes)

Also correct a few mistakes in the fields
2023-05-12 15:40:46 +12:00
0acb8ed85d Merge branch 'main' into feat/nodes/add-w-h-latentsoutput 2023-05-12 15:23:29 +12:00
8c1c9cd702 Merge branch 'main' into feat/nodes/add-randomintinvocation 2023-05-12 15:21:49 +12:00
0ece4686aa fix(nodes): remove Optionals on ImageOutputs (#3392) 2023-05-12 15:21:42 +12:00
af95cef7f9 Merge branch 'main' into fix/nodes/fix-imageoutput-optionals 2023-05-12 15:08:19 +12:00
1eca7a918a feat(ui): make core parameters layout consistent (#3394) 2023-05-12 15:08:07 +12:00
9e6b958023 Merge branch 'main' into feat/ui/consistent-param-layout 2023-05-12 15:06:16 +12:00
f7b99d93ae docs(ui): update ui readme (#3396) 2023-05-12 15:05:55 +12:00
85d03dcd90 Merge branch 'main' into docs/ui/update-ui-readme 2023-05-12 15:04:12 +12:00
032555bcfe fix(model manager): fix string formatting error on model checksum timer (#3397)
The error occurs when loading a model for the first time (or, probably,
after removing its checksum file).
2023-05-12 15:04:01 +12:00
4caa1f19b2 fix(model manager): fix string formatting error on model checksum timer 2023-05-11 19:06:02 -07:00
95d4bd3012 Merge branch 'lstein/bugfix/compel' of github.com:invoke-ai/InvokeAI into lstein/bugfix/compel 2023-05-11 21:13:29 -04:00
037078c8ad make InvokeAIDiffuserComponent.custom_attention_control a classmethod 2023-05-11 21:13:18 -04:00
6de2f66b50 docs(ui): update ui readme 2023-05-12 11:11:59 +10:00
cd7b248eda Add UniPC / Euler Karras / DPMPP_2 Karras / DEIS / DDPM Schedulers (#3388)
**Features:**

- Add UniPC Scheduler
- Add Euler Karras Scheduler
- Add DPMPP_2 Karras Scheduler
- Add DEIS Scheduler
- Add DDPM Scheduler

**Other:**

- Renamed schedulers to their accurate names: _a = Ancestral, _k =
Karras
- Fix scheduler not defaulting correctly to DDIM.
- Code split SCHEDULER_MAP so it's consistently loaded from the same
place.

**Known Bugs:**

- dpmpp_2s not working in img2img for denoising values < 0.8. This
seems to be an upstream bug, so I've disabled it in img2img and canvas
until it is fixed.
https://github.com/huggingface/diffusers/issues/1866
2023-05-12 09:06:22 +12:00
6d8c077f4e Merge branch 'main' into unipc-sched 2023-05-12 05:59:13 +12:00
97127e560e Disable dpmpp_2s in img2img & unifiedCanvas
... until upstream bug is fixed.
2023-05-12 04:51:58 +12:00
27dc07d95a Set zero eta by default(fix ddim scheduler error) 2023-05-11 18:49:27 +03:00
f7dc171c4f Rename default schedulers across the app 2023-05-12 03:44:20 +12:00
4b957edfec Add DDPM Scheduler 2023-05-12 03:18:34 +12:00
46ca7718d9 Add DEIS Scheduler 2023-05-12 03:10:30 +12:00
b928d7a6e6 Change scheduler names to be accurate
_a = Ancestral
_k = Karras
2023-05-12 02:59:43 +12:00
8a836247c8 Add DPMPP Single, Euler Karras and DPMPP2 Multi Karras Schedulers 2023-05-12 02:23:33 +12:00
95c3644564 fix it again 2023-05-12 00:10:39 +10:00
799cd07174 feat(ui): make core parameters layout consistent 2023-05-11 22:45:53 +10:00
9af385468d feat(ui): expand config options
individual SD features (eg Noise, Variation, etc) may now be disabled - stuff which is not ready for consumption in the commercial product.
2023-05-11 22:42:13 +10:00
3487388788 Merge branch 'unipc-sched' of https://github.com/blessedcoolant/InvokeAI into unipc-sched 2023-05-12 00:40:24 +12:00
9a383e456d Codesplit SCHEDULER_MAP for reusage 2023-05-12 00:40:03 +12:00
805f9f8f4a Merge branch 'main' into unipc-sched 2023-05-12 00:24:55 +12:00
52aa0c9bbd ui: miscellaneous fixes (#3386) 2023-05-12 00:21:29 +12:00
7f5f4689cc fix(ui): clear progress image on cancel 2023-05-11 22:20:37 +10:00
a3f81f4b98 fix(ui): fix results not displaying
- fix for commercial product
2023-05-11 22:20:37 +10:00
15c59e606f feat(ui): add spinner to gallery progress images
- otherwise you may think you can click it but you cannot
2023-05-11 22:20:37 +10:00
40d4cabecd feat(ui): improve image overlay 2023-05-11 22:20:37 +10:00
3493c8119b feat(ui): improve image preview css and fallback 2023-05-11 22:20:30 +10:00
c1e7460d39 Merge branch 'main' into unipc-sched 2023-05-12 00:11:09 +12:00
3ffff023b2 Add missing key to scheduler_map
It was breaking coz the sampler was not being reset. So needs a key on each. Will simplify this later.
2023-05-12 00:08:50 +12:00
f9384be59b fix(ui): fix init image causing overflow 2023-05-11 20:55:30 +10:00
6cf308004a fix(nodes): remove Optionals on ImageOutputs 2023-05-11 20:54:57 +10:00
d1029138d2 Default to DDIM if scheduler is missing 2023-05-11 22:54:35 +12:00
06b5800d28 Add UniPC Scheduler 2023-05-11 22:43:18 +12:00
483f2ccb56 feat(nodes): add RandomIntInvocation
just outputs a single random int
2023-05-11 20:33:32 +10:00
93ced0bec6 feat(nodes): add w/h to latents outputs
This reduces the number of nodes needed when working with latents (ie fewer plain integer value nodes)

Also correct a few mistakes in the fields
2023-05-11 20:32:55 +10:00
4333852c37 fix(nodes): fix missing context arg in LatentsToLatents 2023-05-11 19:28:42 +10:00
3baa230077 Merge branch 'main' into lstein/bugfix/compel 2023-05-11 00:50:45 -04:00
9e594f9018 pad conditioning tensors to same length
fixes crash when prompt length is greater than 75 tokens
2023-05-11 00:34:15 -04:00
b0c41b4828 filter our websocket errors (#3382)
Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2023-05-11 01:58:40 +00:00
e0d6946b6b fix(nodes): fix metadata test
- `progress_images` is no longer a parameter
- `seamless` needs to be reworked as a model config, removed as a param
2023-05-11 11:55:51 +10:00
bf7ea8309f fix(ui): change tab to img2img when selected initial image 2023-05-11 11:55:51 +10:00
54b65f725f fix(ui): rescale canvas on gallery resize 2023-05-11 11:55:51 +10:00
8ef49c2640 fix(ui): fix canvas img2img if no init image selected 2023-05-11 11:55:51 +10:00
f488b1a7f2 fix(nodes): fix usage of Optional 2023-05-11 11:55:51 +10:00
d2edb7c402 build(ui): add yalc to gitignore 2023-05-11 11:55:51 +10:00
f0a3f07b45 feat(ui): antialias progress images 2023-05-11 11:55:51 +10:00
b42b630583 fix(ui): h/w disabled bug 2023-05-11 11:55:51 +10:00
31a78d571b feat(ui): canvas antialiasing 2023-05-11 11:55:51 +10:00
fdc2232ea0 feat(ui): progress images in gallery and viewer 2023-05-11 11:55:51 +10:00
e94d0b2d40 fix(ui): fix janky gallery image delete 2023-05-11 11:55:51 +10:00
75ccbaee9c fix(ui): disable invoke button as soon as pressed 2023-05-11 11:55:51 +10:00
2848c8397c fix(ui): fix missing images on reload issue
- Mainly an issue for commercial due to incomplete metadata handling
2023-05-11 11:55:51 +10:00
fe8b5193de feat(ui): half-baked use all parameters
until we have a better system for metadata, this will remain half-baked
2023-05-11 11:55:51 +10:00
3d1470399c fix(ui): fix metadataviewer styling 2023-05-11 11:55:51 +10:00
fcf9c63049 fix(ui): fix copying image link 2023-05-11 11:55:51 +10:00
7bfb5640ad cleanup(ui): Remove unused vars + minor bug fixes 2023-05-11 11:55:51 +10:00
15e57e3a3d fix(ui): duplicate gallery in nodes editor 2023-05-11 11:55:51 +10:00
279468c0e8 feat(ui): restore tab names 2023-05-11 11:55:51 +10:00
c565812723 feat(ui): organize parameters panels 2023-05-11 11:55:51 +10:00
ec6c8e2a38 feat(ui): wip layout 2023-05-11 11:55:51 +10:00
77f2690711 fix(ui): remove duplicate gallery 2023-05-11 11:55:51 +10:00
c4b3a24ed7 feat(ui): revert tabs to txt2img/img2img 2023-05-11 11:55:51 +10:00
33c69359c2 feat(ui): add IAICollapse for parameters 2023-05-11 11:55:51 +10:00
864f4bb4af feat(ui): wip img2img layouting 2023-05-11 11:55:51 +10:00
5365f42a04 feat(ui): wip layouting 2023-05-11 11:55:51 +10:00
3dc60254b9 feat(ui): support collect nodes 2023-05-11 11:55:51 +10:00
027a8562d7 fix(ui): default node model selection 2023-05-11 11:55:51 +10:00
34f3a0f0e3 feat(nodes): improve default model choosing output 2023-05-11 11:55:51 +10:00
d0bac1675e fix(nodes): fix ImageOutput Config 2023-05-11 11:55:51 +10:00
4e56c962f4 fix(nodes): fix infill docstrings 2023-05-11 11:55:51 +10:00
4ef0e43759 fix(nodes): remove dataURL invocation 2023-05-11 11:55:51 +10:00
6945d10297 chore(ui): regen api client 2023-05-11 11:55:51 +10:00
4d6cef7ac8 fix(ui): fix types bug 2023-05-11 11:55:51 +10:00
a7786d5ff2 fix(nodes): restore seamless to TextToLatents 2023-05-11 11:55:51 +10:00
6c1de975d9 feat(nodes): add infill nodes 2023-05-11 11:55:51 +10:00
a1079e455a feat(nodes): cleanup unused params, seed generation 2023-05-11 11:55:51 +10:00
5457c7f069 fix(ui): use lodash-es instead of lodash 2023-05-11 11:55:51 +10:00
b8c1a3f96c chore(ui): remove unused babelrc & npm script 2023-05-11 11:55:51 +10:00
cee8e85f76 chore(ui): bump redux-remember 2023-05-11 11:55:51 +10:00
09f166577e feat(ui): migrate to redux-remember 2023-05-11 11:55:51 +10:00
bcc21531fb feat(ui): update for InfillInvocation 2023-05-11 11:55:51 +10:00
da4eacdffe feat(nodes): add InfillInvocation 2023-05-11 11:55:51 +10:00
6102e560ba feat(nodes): add LatentsToImage node (VAE encode) 2023-05-11 11:55:51 +10:00
ff3aa57117 feat(ui): fix endless gallery scroll for single col layout 2023-05-11 11:55:51 +10:00
49db6f4fac fix(nodes): fix trivial typing issues 2023-05-11 11:55:51 +10:00
20f6a597ab fix(nodes): add MetadataColorField 2023-05-11 11:55:51 +10:00
04c453721c feat(ui): tweak gallery loading indicator 2023-05-11 11:55:51 +10:00
350ffecc1f feat(ui): endless gallery scroll 2023-05-11 11:55:51 +10:00
b0557aa16b fix(ui): fix currentimagepreview not working for uploads 2023-05-11 11:55:51 +10:00
1c9429a6ea feat(ui): wip canvas 2023-05-11 11:55:51 +10:00
206e6b1730 feat(nodes): wip inpaint node 2023-05-11 11:55:51 +10:00
357cee2849 fix(nodes): fix cfg scale min value 2023-05-11 11:55:51 +10:00
0b49997bb6 feat(nodes): allow uploaded images to be any ImageType (eg intermediates) 2023-05-11 11:55:51 +10:00
5e09dd380d Revert "feat(nodes): free gpu mem after invocation"
This reverts commit 99cb33f477306d5dcc455efe04053ce41b8d85bd.
2023-05-11 11:55:51 +10:00
c7303adb0d feat(ui): fix generation mode logic 2023-05-11 11:55:51 +10:00
ed1f096a6f feat(ui): wip canvas migration 4 2023-05-11 11:55:51 +10:00
6ab5d28cf3 feat(ui): wip canvas migration, createListenerMiddleware 2023-05-11 11:55:51 +10:00
a75148cb16 feat(nodes): free gpu mem after invocation 2023-05-11 11:55:51 +10:00
f7bbc4004a feat(ui): wip canvas nodes migration 3 2023-05-11 11:55:51 +10:00
cee21ca082 feat(ui): wip canvas nodes migration 2 2023-05-11 11:55:51 +10:00
08ec12b391 feat(ui): wip canvas nodes migration 2023-05-11 11:55:51 +10:00
ff5e2a9a8c chore(ui): regen api client 2023-05-11 11:55:51 +10:00
e0b9b5cc6c feat(nodes): add dataURL to image node 2023-05-11 11:55:51 +10:00
aca4770481 fixed compel.py as requested 2023-05-10 21:40:44 -04:00
5d5157fc65 make conditioning.py work with compel 1.1.5 2023-05-10 18:08:33 -04:00
fb6ef61a4d change path for locale (#3381)
Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2023-05-10 10:30:17 -04:00
ee24ad7b13 fix(nodes): fix broken docs routes 2023-05-10 08:28:17 -04:00
f8e90ba3f0 feat(nodes): add ui build static route 2023-05-10 08:28:17 -04:00
ad0b70ca23 fix(nodes): fix #3306 (#3377)
Check if the cache has the object before deleting it.
2023-05-10 17:39:45 +12:00
7dfa135b2c fix(nodes): fix #3306
Check if the cache has the object before deleting it.
2023-05-10 15:29:10 +10:00
beeaa05658 Update dependencies to get deterministic image generation behavior (main branch) (#3354)
This PR updates to `xformers ~= 0.0.19` and `torch ~= 2.0.0`, which
together seem to solve the non-deterministic image generation issue that
was previously seen with earlier versions of `xformers`.
2023-05-10 00:10:51 -04:00
6b6d654f60 Merge branch 'main' into enhance/update-dependencies 2023-05-09 23:56:46 -04:00
853c83d0c2 surface detail field for 403 errors 2023-05-09 12:40:19 +10:00
1809990ed4 if backend returns an error, show it in toast 2023-05-09 11:09:36 +10:00
79d49853d2 use websocket transport first for socket.io 2023-05-09 11:01:02 +10:00
bd0ad59c27 bump compel version 2023-05-07 15:22:46 -04:00
cce40acba5 Merge branch 'enhance/update-dependencies' of github.com:invoke-ai/InvokeAI into enhance/update-dependencies 2023-05-07 15:22:31 -04:00
bc9491ab69 bump compel version 2023-05-07 15:21:24 -04:00
f28632980d Merge branch 'main' into lstein/global-configuration 2023-05-07 07:52:46 -04:00
b909bac0dc Merge branch 'main' into enhance/update-dependencies 2023-05-07 21:44:43 +12:00
42d938fda5 remove debugging statement 2023-05-06 23:54:11 -04:00
8f80ba9520 update dependencies to get deterministic image generation 2023-05-06 23:09:24 -04:00
25ce47c44f remove reference to globals in compel.py 2023-05-06 22:49:35 -04:00
afd2e32092 Merge branch 'main' into lstein/global-configuration 2023-05-06 21:20:25 -04:00
742ed19d66 add missing config module 2023-05-04 01:20:30 -04:00
29c2ada23c add test for the configuration module 2023-05-04 00:45:52 -04:00
e4196bbe5b adjust non-app modules to use new config system 2023-05-04 00:43:51 -04:00
15ffb53e59 remove globals, args, generate and the legacy CLI 2023-05-03 23:36:51 -04:00
90054ddf0d use InvokeAISettings for app-wide configuration 2023-05-03 22:30:30 -04:00
9ecca13229 Add Convert Model Endpoint 2023-04-08 18:05:21 -04:00
455 changed files with 9319 additions and 11984 deletions

View File

@ -80,12 +80,7 @@ jobs:
uses: actions/checkout@v3
- name: set test prompt to main branch validation
if: ${{ github.ref == 'refs/heads/main' }}
run: echo "TEST_PROMPTS=tests/preflight_prompts.txt" >> ${{ matrix.github-env }}
- name: set test prompt to Pull Request validation
if: ${{ github.ref != 'refs/heads/main' }}
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
run:echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
- name: setup python
uses: actions/setup-python@v4
@ -105,12 +100,6 @@ jobs:
id: run-pytest
run: pytest
- name: set INVOKEAI_OUTDIR
run: >
python -c
"import os;from invokeai.backend.globals import Globals;OUTDIR=os.path.join(Globals.root,str('outputs'));print(f'INVOKEAI_OUTDIR={OUTDIR}')"
>> ${{ matrix.github-env }}
- name: run invokeai-configure
id: run-preload-models
env:
@ -129,15 +118,20 @@ jobs:
HF_HUB_OFFLINE: 1
HF_DATASETS_OFFLINE: 1
TRANSFORMERS_OFFLINE: 1
INVOKEAI_OUTDIR: ${{ github.workspace }}/results
run: >
invokeai
--no-patchmatch
--no-nsfw_checker
--from_file ${{ env.TEST_PROMPTS }}
--precision=float32
--always_use_cpu
--outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
--from_file ${{ env.TEST_PROMPTS }}
- name: Archive results
id: archive-results
env:
INVOKEAI_OUTDIR: ${{ github.workspace }}/results
uses: actions/upload-artifact@v3
with:
name: results

2
.gitignore vendored
View File

@ -201,6 +201,8 @@ checkpoints
# If it's a Mac
.DS_Store
invokeai/frontend/web/dist/*
# Let the frontend manage its own gitignore
!invokeai/frontend/web/*

View File

@ -247,8 +247,8 @@ class InvokeAiInstance:
pip[
"install",
"--require-virtualenv",
"torch",
"torchvision",
"torch~=2.0.0",
"torchvision>=0.14.1",
"--force-reinstall",
"--find-links" if find_links is not None else None,
find_links,

View File

@ -7,7 +7,6 @@ from typing import types
from ..services.default_graphs import create_system_graphs
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ...backend import Globals
from ..services.model_manager_initializer import get_model_manager
from ..services.restoration_services import RestorationServices
from ..services.graph import GraphExecutionState, LibraryGraph
@ -42,17 +41,8 @@ class ApiDependencies:
invoker: Invoker = None
@staticmethod
def initialize(config, event_handler_id: int, logger: types.ModuleType=logger):
Globals.try_patchmatch = config.patchmatch
Globals.always_use_cpu = config.always_use_cpu
Globals.internet_available = config.internet_available and check_internet()
Globals.disable_xformers = not config.xformers
Globals.ckpt_convert = config.ckpt_convert
# TO DO: Use the config to select the logger rather than use the default
# invokeai logging module
logger.info(f"Internet connectivity is {Globals.internet_available}")
logger.info(f"Internet connectivity is {config.internet_available}")
events = FastAPIEventService(event_handler_id)
@ -72,7 +62,6 @@ class ApiDependencies:
services = InvocationServices(
model_manager=get_model_manager(config,logger),
events=events,
logger=logger,
latents=latents,
images=images,
metadata=metadata,
@ -85,6 +74,8 @@ class ApiDependencies:
),
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config,logger),
configuration=config,
logger=logger,
)
create_system_graphs(services.graph_library)

View File

@ -83,7 +83,7 @@ async def get_thumbnail(
status_code=201,
)
async def upload_image(
file: UploadFile, request: Request, response: Response
file: UploadFile, image_type: ImageType, request: Request, response: Response
) -> ImageResponse:
if not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
@ -99,21 +99,21 @@ async def upload_image(
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
saved_image = ApiDependencies.invoker.services.images.save(
ImageType.UPLOAD, filename, img
image_type, filename, img
)
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
image_url = ApiDependencies.invoker.services.images.get_uri(
ImageType.UPLOAD, saved_image.image_name
image_type, saved_image.image_name
)
thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
ImageType.UPLOAD, saved_image.image_name, True
image_type, saved_image.image_name, True
)
res = ImageResponse(
image_type=ImageType.UPLOAD,
image_type=image_type,
image_name=saved_image.image_name,
image_url=image_url,
thumbnail_url=thumbnail_url,

View File

@ -1,7 +1,7 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and Kent Keirsey (https://github.com/hipsterusername)
import shutil
import asyncio
import os
from typing import Annotated, Any, List, Literal, Optional, Union
from fastapi.routing import APIRouter, HTTPException
@ -47,10 +47,8 @@ class CreateModelResponse(BaseModel):
class ConversionRequest(BaseModel):
name: str = Field(description="The name of the new model")
info: CkptModelInfo = Field(description="The converted model info")
save_location: str = Field(description="The path to save the converted model weights")
class ConvertedModelResponse(BaseModel):
name: str = Field(description="The name of the new model")
info: DiffusersModelInfo = Field(description="The converted model info")
@ -124,6 +122,95 @@ async def delete_model(model_name: str) -> None:
logger.error(f"Model not found")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
# TODO: Refactor these support functions below to live somewhere more appropriate
def get_model_info(model_name: str):
model_info = ApiDependencies.invoker.services.model_manager.model_info(
model_name=model_name
)
if not model_info:
raise HTTPException(status_code=404, detail=f"Unable to retrieve model info for '{model_name}'")
return model_info
def ckpt_validate(model_info: dict, model_name: str):
if "weights" not in model_info:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' is not a valid checkpoint model")
def get_paths(model: ConversionRequest, root: Path) -> tuple:
model_info = get_model_info(model.name)
ckpt_path = Path(model_info.weights)
config_path = Path(model_info.config)
if not ckpt_path.is_absolute():
ckpt_path = Path(root, ckpt_path)
if config_path and not config_path.is_absolute():
config_path = Path(root, config_path)
return ckpt_path, config_path
def get_diffusers_path(convert_request: ConversionRequest, model_name: str) -> Path:
if convert_request.save_location == "root":
diffusers_path = Path(global_converted_ckpts_dir(), f"{model_name}_diffusers")
elif convert_request.save_location == "custom" and convert_request.save_location is not None:
diffusers_path = Path(convert_request.save_location, f"{model_name}_diffusers")
else:
raise ValueError("Invalid save_location value")
if diffusers_path.exists():
shutil.rmtree(diffusers_path)
return diffusers_path
@models_router.post(
"/{model_to_convert}",
operation_id="convert_model",
responses={
200: {
"model_response": "Model converted successfully.",
}
},
)
async def convert_model(convert_request: ConversionRequest) -> ConvertedModelResponse:
"""Convert Model"""
opt=Args()
args = opt.parse_args()
# Set the root directory for static files and relative paths
args.root_dir = os.path.expanduser(args.root_dir or "..")
if not os.path.isabs(args.outdir):
args.outdir = os.path.join(args.root_dir, args.outdir)
# normalize the config directory relative to root
if not os.path.isabs(opt.conf):
opt.conf = os.path.normpath(os.path.join(Globals.root, opt.conf))
model_info = get_model_info(convert_request.name)
ckpt_validate(model_info, convert_request.name)
ckpt_path, original_config_file = get_paths(convert_request, Globals.root)
diffusers_path = get_diffusers_path(convert_request, convert_request.name)
ApiDependencies.invoker.services.model_manager.convert_and_import(
ckpt_path,
diffusers_path,
model_name=convert_request.name,
model_description=model_info.description,
vae=None,
original_config_file=original_config_file,
commit_to_conf=opt.conf,
)
model_info = get_model_info(convert_request.name)
convert_response = ConvertedModelResponse(name=f"{convert_request.name}_diffusers", info=model_info)
print(f">> Model Converted: {convert_request.name}")
return convert_response
# @socketio.on("convertToDiffusers")
# def convert_to_diffusers(model_to_convert: dict):

View File

@ -13,11 +13,11 @@ from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.schema import schema
from ..backend import Args
from .api.dependencies import ApiDependencies
from .api.routers import images, sessions, models
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation
from .services.config import InvokeAIAppConfig
# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
@ -33,30 +33,25 @@ app.add_middleware(
middleware_id=event_handler_id,
)
# Add CORS
# TODO: use configuration for this
origins = []
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
socket_io = SocketIO(app)
config = {}
# initialize config
# this is a module global
app_config = InvokeAIAppConfig()
# Add startup event to load dependencies
@app.on_event("startup")
async def startup_event():
config = Args()
config.parse_args()
app.add_middleware(
CORSMiddleware,
allow_origins=app_config.allow_origins,
allow_credentials=app_config.allow_credentials,
allow_methods=app_config.allow_methods,
allow_headers=app_config.allow_headers,
)
ApiDependencies.initialize(
config=config, event_handler_id=event_handler_id, logger=logger
config=app_config, event_handler_id=event_handler_id, logger=logger
)
@ -126,7 +121,6 @@ app.openapi = custom_openapi
# Override API doc favicons
app.mount("/static", StaticFiles(directory="static/dream_web"), name="static")
@app.get("/docs", include_in_schema=False)
def overridden_swagger():
return get_swagger_ui_html(
@ -144,17 +138,16 @@ def overridden_redoc():
redoc_favicon_url="/static/favicon.ico",
)
# Must mount *after* the other routes else it borks em
app.mount("/", StaticFiles(directory="invokeai/frontend/web/dist", html=True), name="ui")
def invoke_api():
# Start our own event loop for eventing usage
# TODO: determine if there's a better way to do this
loop = asyncio.new_event_loop()
config = uvicorn.Config(app=app, host="0.0.0.0", port=9090, loop=loop)
config = uvicorn.Config(app=app, host=app_config.host, port=app_config.port, loop=loop)
# Use access_log to turn off logging
server = uvicorn.Server(config)
loop.run_until_complete(server.serve())
if __name__ == "__main__":
invoke_api()

View File

@ -285,3 +285,19 @@ class DrawExecutionGraphCommand(BaseCommand):
nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
plt.axis("off")
plt.show()
class SortedHelpFormatter(argparse.HelpFormatter):
def _iter_indented_subactions(self, action):
try:
get_subactions = action._get_subactions
except AttributeError:
pass
else:
self._indent()
if isinstance(action, argparse._SubParsersAction):
for subaction in sorted(get_subactions(), key=lambda x: x.dest):
yield subaction
else:
for subaction in get_subactions():
yield subaction
self._dedent()

View File

@ -11,9 +11,10 @@ from pathlib import Path
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
import invokeai.backend.util.logging as logger
from ...backend import ModelManager, Globals
from ...backend import ModelManager
from ..invocations.baseinvocation import BaseInvocation
from .commands import BaseCommand
from ..services.invocation_services import InvocationServices
# singleton object, class variable
completer = None
@ -131,13 +132,13 @@ class Completer(object):
readline.redisplay()
self.linebuffer = None
def set_autocompleter(model_manager: ModelManager) -> Completer:
def set_autocompleter(services: InvocationServices) -> Completer:
global completer
if completer:
return completer
completer = Completer(model_manager)
completer = Completer(services.model_manager)
readline.set_completer(completer.complete)
# pyreadline3 does not have a set_auto_history() method
@ -153,7 +154,7 @@ def set_autocompleter(model_manager: ModelManager) -> Completer:
readline.parse_and_bind("set skip-completed-text on")
readline.parse_and_bind("set show-all-if-ambiguous on")
histfile = Path(Globals.root, ".invoke_history")
histfile = Path(services.configuration.root_dir / ".invoke_history")
try:
readline.read_history_file(histfile)
readline.set_history_length(1000)

View File

@ -4,13 +4,14 @@ import argparse
import os
import re
import shlex
import sys
import time
from typing import (
Union,
get_type_hints,
)
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError
from pydantic.fields import Field
@ -19,8 +20,7 @@ from invokeai.app.services.metadata import PngMetadataService
from .services.default_graphs import create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..backend import Args
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, SortedHelpFormatter
from .cli.completer import set_autocompleter
from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
@ -34,7 +34,7 @@ from .services.invocation_services import InvocationServices
from .services.invoker import Invoker
from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage
from .services.config import get_invokeai_config
class CliCommand(BaseModel):
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
@ -64,7 +64,7 @@ def add_invocation_args(command_parser):
def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
# Create invocation parser
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(formatter_class=SortedHelpFormatter)
def exit(*args, **kwargs):
raise InvalidArgs
@ -189,24 +189,25 @@ def invoke_all(context: CliContext):
def invoke_cli():
config = Args()
config.parse_args()
# this gets the basic configuration
config = get_invokeai_config()
# get the optional list of invocations to execute on the command line
parser = config.get_parser()
parser.add_argument('commands',nargs='*')
invocation_commands = parser.parse_args().commands
# get the optional file to read commands from.
# Simplest is to use it for STDIN
if infile := config.from_file:
sys.stdin = open(infile,"r")
model_manager = get_model_manager(config,logger=logger)
# This initializes the autocompleter and returns it.
# Currently nothing is done with the returned Completer
# object, but the object can be used to change autocompletion
# behavior on the fly, if desired.
set_autocompleter(model_manager)
events = EventServiceBase()
output_folder = config.output_path
metadata = PngMetadataService()
output_folder = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../../../outputs")
)
# TODO: build a file/path manager?
db_location = os.path.join(output_folder, "invokeai.db")
@ -226,6 +227,7 @@ def invoke_cli():
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config,logger=logger),
logger=logger,
configuration=config,
)
system_graphs = create_system_graphs(services.graph_library)
@ -241,9 +243,17 @@ def invoke_cli():
# print(services.session_manager.list())
context = CliContext(invoker, session, parser)
set_autocompleter(services)
while True:
command_line_args_exist = len(invocation_commands) > 0
done = False
while not done:
try:
if command_line_args_exist:
cmd_input = invocation_commands.pop(0)
done = len(invocation_commands) == 0
else:
cmd_input = input("invoke> ")
except (KeyboardInterrupt, EOFError):
# Ctrl-c exits
@ -368,6 +378,9 @@ def invoke_cli():
invoker.services.logger.warning('Invalid command, use "help" to list commands')
continue
except ValidationError:
invoker.services.logger.warning('Invalid command arguments, run "<command> --help" for summary')
except SessionError:
# Start a new session
invoker.services.logger.warning("Session error: creating a new session")

View File

@ -3,12 +3,12 @@
from typing import Literal, Optional
import numpy as np
import numpy.random
from pydantic import Field
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from .baseinvocation import (
BaseInvocation,
InvocationConfig,
InvocationContext,
BaseInvocationOutput,
)
@ -50,11 +50,11 @@ class RandomRangeInvocation(BaseInvocation):
default=np.iinfo(np.int32).max, description="The exclusive high value"
)
size: int = Field(default=1, description="The number of values to generate")
seed: Optional[int] = Field(
seed: int = Field(
ge=0,
le=np.iinfo(np.int32).max,
description="The seed for the RNG",
default_factory=lambda: numpy.random.randint(0, np.iinfo(np.int32).max),
le=SEED_MAX,
description="The seed for the RNG (omit for random)",
default_factory=get_random_seed,
)
def invoke(self, context: InvocationContext) -> IntCollectionOutput:

View File

@ -16,8 +16,6 @@ from compel.prompt_parser import (
Fragment,
)
from invokeai.backend.globals import Globals
class ConditioningField(BaseModel):
conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
@ -100,9 +98,10 @@ class CompelInvocation(BaseInvocation):
# TODO: support legacy blend?
prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(prompt_str)
conjunction = Compel.parse_prompt_string(prompt_str)
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
if getattr(Globals, "log_tokenization", False):
if context.services.configuration.log_tokenization:
log_tokenization_for_prompt_object(prompt, tokenizer)
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)

View File

@ -1,15 +1,17 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from functools import partial
from typing import Literal, Optional, Union
from typing import Literal, Optional, Union, get_args
import numpy as np
from torch import Tensor
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.models.image import ColorField, ImageField, ImageType
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.generator.inpaint import infill_methods
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
@ -17,7 +19,8 @@ from ...backend.stable_diffusion import PipelineIntermediateState
from ..util.step_callback import stable_diffusion_step_callback
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
INFILL_METHODS = Literal[tuple(infill_methods())]
DEFAULT_INFILL_METHOD = 'patchmatch' if 'patchmatch' in get_args(INFILL_METHODS) else 'tile'
class SDImageInvocation(BaseModel):
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
@ -44,15 +47,13 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# TODO: consider making prompt optional to enable providing prompt through a link
# fmt: off
prompt: Optional[str] = Field(description="The prompt to generate an image from")
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
model: str = Field(default="", description="The model to use (currently ignored)")
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
# fmt: on
# TODO: pass this an emitter method or something? or a session for dispatching?
@ -148,7 +149,6 @@ class ImageToImageInvocation(TextToImageInvocation):
self.image.image_type, self.image.image_name
)
)
mask = None
if self.fit:
image = image.resize((self.width, self.height))
@ -165,7 +165,6 @@ class ImageToImageInvocation(TextToImageInvocation):
outputs = Img2Img(model).generate(
prompt=self.prompt,
init_image=image,
init_mask=mask,
step_callback=partial(self.dispatch_progress, context, source_node_id),
**self.dict(
exclude={"prompt", "image", "mask"}
@ -197,7 +196,6 @@ class ImageToImageInvocation(TextToImageInvocation):
image=result_image,
)
class InpaintInvocation(ImageToImageInvocation):
"""Generates an image using inpaint."""
@ -205,6 +203,17 @@ class InpaintInvocation(ImageToImageInvocation):
# Inputs
mask: Union[ImageField, None] = Field(description="The mask")
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)")
seam_strength: float = Field(
default=0.75, gt=0, le=1, description="The seam inpaint strength"
)
seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint")
tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)")
infill_method: INFILL_METHODS = Field(default=DEFAULT_INFILL_METHOD, description="The method used to infill empty regions (px)")
inpaint_width: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The width of the inpaint region (px)")
inpaint_height: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The height of the inpaint region (px)")
inpaint_fill: Optional[ColorField] = Field(default=ColorField(r=127, g=127, b=127, a=255), description="The solid infill method color")
inpaint_replace: float = Field(
default=0.0,
ge=0.0,

View File

@ -1,5 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io
from typing import Literal, Optional
import numpy
@ -32,14 +33,12 @@ class ImageOutput(BaseInvocationOutput):
# fmt: off
type: Literal["image"] = "image"
image: ImageField = Field(default=None, description="The output image")
width: Optional[int] = Field(default=None, description="The width of the image in pixels")
height: Optional[int] = Field(default=None, description="The height of the image in pixels")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
class Config:
schema_extra = {
"required": ["type", "image", "width", "height", "mode"]
}
schema_extra = {"required": ["type", "image", "width", "height"]}
def build_image_output(
@ -54,7 +53,6 @@ def build_image_output(
image=image_field,
width=image.width,
height=image.height,
mode=image.mode,
)

View File

@ -0,0 +1,233 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal, Optional, Union, get_args
import numpy as np
import math
from PIL import Image, ImageOps
from pydantic import Field
from invokeai.app.invocations.image import ImageOutput, build_image_output
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.image_util.patchmatch import PatchMatch
from ..models.image import ColorField, ImageField, ImageType
from .baseinvocation import (
BaseInvocation,
InvocationContext,
)
def infill_methods() -> list[str]:
methods = [
"tile",
"solid",
]
if PatchMatch.patchmatch_available():
methods.insert(0, "patchmatch")
return methods
INFILL_METHODS = Literal[tuple(infill_methods())]
DEFAULT_INFILL_METHOD = (
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
)
def infill_patchmatch(im: Image.Image) -> Image.Image:
if im.mode != "RGBA":
return im
# Skip patchmatch if patchmatch isn't available
if not PatchMatch.patchmatch_available():
return im
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
im_patched_np = PatchMatch.inpaint(
im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3
)
im_patched = Image.fromarray(im_patched_np, mode="RGB")
return im_patched
def get_tile_images(image: np.ndarray, width=8, height=8):
_nrows, _ncols, depth = image.shape
_strides = image.strides
nrows, _m = divmod(_nrows, height)
ncols, _n = divmod(_ncols, width)
if _m != 0 or _n != 0:
return None
return np.lib.stride_tricks.as_strided(
np.ravel(image),
shape=(nrows, ncols, height, width, depth),
strides=(height * _strides[0], width * _strides[1], *_strides),
writeable=False,
)
def tile_fill_missing(
im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None
) -> Image.Image:
# Only fill if there's an alpha layer
if im.mode != "RGBA":
return im
a = np.asarray(im, dtype=np.uint8)
tile_size_tuple = (tile_size, tile_size)
# Get the image as tiles of a specified size
tiles = get_tile_images(a, *tile_size_tuple).copy()
# Get the mask as tiles
tiles_mask = tiles[:, :, :, :, 3]
# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
tmask_shape = tiles_mask.shape
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
n, ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
tiles_mask = tiles_mask > 0
tiles_mask = tiles_mask.reshape((n, ny)).all(axis=1)
# Get RGB tiles in single array and filter by the mask
tshape = tiles.shape
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), *tiles.shape[2:]))
filtered_tiles = tiles_all[tiles_mask]
if len(filtered_tiles) == 0:
return im
# Find all invalid tiles and replace with a random valid tile
replace_count = (tiles_mask == False).sum()
rng = np.random.default_rng(seed=seed)
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[
rng.choice(filtered_tiles.shape[0], replace_count), :, :, :
]
# Convert back to an image
tiles_all = tiles_all.reshape(tshape)
tiles_all = tiles_all.swapaxes(1, 2)
st = tiles_all.reshape(
(
math.prod(tiles_all.shape[0:2]),
math.prod(tiles_all.shape[2:4]),
tiles_all.shape[4],
)
)
si = Image.fromarray(st, mode="RGBA")
return si
class InfillColorInvocation(BaseInvocation):
"""Infills transparent areas of an image with a solid color"""
type: Literal["infill_rgba"] = "infill_rgba"
image: Optional[ImageField] = Field(default=None, description="The image to infill")
color: Optional[ColorField] = Field(
default=ColorField(r=127, g=127, b=127, a=255),
description="The color to use to infill",
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
infilled = Image.alpha_composite(solid_bg, image)
infilled.paste(image, (0, 0), image.split()[-1])
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, infilled, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image,
)
class InfillTileInvocation(BaseInvocation):
"""Infills transparent areas of an image with tiles of the image"""
type: Literal["infill_tile"] = "infill_tile"
image: Optional[ImageField] = Field(default=None, description="The image to infill")
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
seed: int = Field(
ge=0,
le=SEED_MAX,
description="The seed to use for tile generation (omit for random)",
default_factory=get_random_seed,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
infilled = tile_fill_missing(
image.copy(), seed=self.seed, tile_size=self.tile_size
)
infilled.paste(image, (0, 0), image.split()[-1])
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, infilled, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image,
)
class InfillPatchMatchInvocation(BaseInvocation):
"""Infills transparent areas of an image using the PatchMatch algorithm"""
type: Literal["infill_patchmatch"] = "infill_patchmatch"
image: Optional[ImageField] = Field(default=None, description="The image to infill")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
if PatchMatch.patchmatch_available():
infilled = infill_patchmatch(image.copy())
else:
raise ValueError("PatchMatch is not available on this system")
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, infilled, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image,
)
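
The infill nodes above are thin wrappers around two standalone helpers, `infill_patchmatch` and `tile_fill_missing`, which operate directly on PIL images. Below is a minimal sketch of the tile helper outside the node graph, assuming the new file is importable as `invokeai.app.invocations.infill` (the diff does not show its path) and that Pillow and numpy are installed:

```
from PIL import Image
from invokeai.app.invocations.infill import (
    DEFAULT_INFILL_METHOD,
    infill_methods,
    tile_fill_missing,
)

# An RGBA canvas with a fully transparent square hole in the middle.
im = Image.new("RGBA", (128, 128), (200, 80, 40, 255))
im.paste((0, 0, 0, 0), (32, 32, 96, 96))

print(infill_methods())       # ['patchmatch', 'tile', 'solid'] when PatchMatch is available
print(DEFAULT_INFILL_METHOD)  # falls back to 'tile' otherwise

# Replace the transparent tiles with random opaque tiles sampled from the image,
# mirroring what InfillTileInvocation does; a fixed seed makes the fill reproducible.
filled = tile_fill_missing(im, tile_size=32, seed=42)
filled.save("tile_filled.png")
```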

View File

@ -1,11 +1,13 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import random
from typing import Literal, Optional
from typing import Literal, Optional, Union
import einops
from pydantic import BaseModel, Field
import torch
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.step_callback import stable_diffusion_step_callback
@ -13,7 +15,9 @@ from ...backend.model_management.model_manager import ModelManager
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.image_util.seamless import configure_model_padding
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
import numpy as np
from ..services.image_storage import ImageType
@ -37,41 +41,55 @@ class LatentsField(BaseModel):
class LatentsOutput(BaseInvocationOutput):
"""Base class for invocations that output latents"""
#fmt: off
type: Literal["latent_output"] = "latent_output"
type: Literal["latents_output"] = "latents_output"
# Inputs
latents: LatentsField = Field(default=None, description="The output latents")
width: int = Field(description="The width of the latents in pixels")
height: int = Field(description="The height of the latents in pixels")
#fmt: on
def build_latents_output(latents_name: str, latents: torch.Tensor):
return LatentsOutput(
latents=LatentsField(latents_name=latents_name),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)
class NoiseOutput(BaseInvocationOutput):
"""Invocation noise output"""
#fmt: off
type: Literal["noise_output"] = "noise_output"
# Inputs
noise: LatentsField = Field(default=None, description="The output noise")
width: int = Field(description="The width of the noise in pixels")
height: int = Field(description="The height of the noise in pixels")
#fmt: on
# TODO: this seems like a hack
scheduler_map = dict(
ddim=diffusers.DDIMScheduler,
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_euler=diffusers.EulerDiscreteScheduler,
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
k_heun=diffusers.HeunDiscreteScheduler,
k_lms=diffusers.LMSDiscreteScheduler,
plms=diffusers.PNDMScheduler,
)
def build_noise_output(latents_name: str, latents: torch.Tensor):
return NoiseOutput(
noise=LatentsField(latents_name=latents_name),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)
SAMPLER_NAME_VALUES = Literal[
tuple(list(scheduler_map.keys()))
tuple(list(SCHEDULER_MAP.keys()))
]
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class = scheduler_map.get(scheduler_name,'ddim')
scheduler = scheduler_class.from_config(model.scheduler.config)
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
scheduler_config = model.scheduler.config
if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py
if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False
@ -102,17 +120,13 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
return x
def random_seed():
return random.randint(0, np.iinfo(np.uint32).max)
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""
type: Literal["noise"] = "noise"
# Inputs
seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use", default_factory=get_random_seed)
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )
@ -131,9 +145,7 @@ class NoiseInvocation(BaseInvocation):
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, noise)
return NoiseOutput(
noise=LatentsField(latents_name=name)
)
return build_noise_output(latents_name=name, latents=noise)
# Text to image
@ -149,11 +161,10 @@ class TextToLatentsInvocation(BaseInvocation):
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
model: str = Field(default="", description="The model to use (currently ignored)")
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
model: str = Field(default="", description="The model to use (currently ignored)")
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
# fmt: on
# Schema customisation
@ -218,7 +229,7 @@ class TextToLatentsInvocation(BaseInvocation):
h_symmetry_time_pct=None,#h_symmetry_time_pct,
v_symmetry_time_pct=None#v_symmetry_time_pct,
),
).add_scheduler_args_if_applicable(model.scheduler, eta=None)#ddim_eta)
).add_scheduler_args_if_applicable(model.scheduler, eta=0.0)#ddim_eta)
return conditioning_data
@ -250,9 +261,7 @@ class TextToLatentsInvocation(BaseInvocation):
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
return LatentsOutput(
latents=LatentsField(latents_name=name)
)
return build_latents_output(latents_name=name, latents=result_latents)
class LatentsToLatentsInvocation(TextToLatentsInvocation):
@ -260,6 +269,10 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
type: Literal["l2l"] = "l2l"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
strength: float = Field(default=0.5, description="The strength of the latents to use")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
@ -271,10 +284,6 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
},
}
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
strength: float = Field(default=0.5, description="The strength of the latents to use")
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name)
@ -287,7 +296,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
self.dispatch_progress(context, source_node_id, state)
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(model)
conditioning_data = self.get_conditioning_data(context, model)
# TODO: Verify the noise is the right size
@ -295,11 +304,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
latent, device=model.device, dtype=latent.dtype
)
timesteps, _ = model.get_img2img_timesteps(
self.steps,
self.strength,
device=model.device,
)
timesteps, _ = model.get_img2img_timesteps(self.steps, self.strength)
result_latents, result_attention_map_saver = model.latents_from_embeddings(
latents=initial_latents,
@ -315,9 +320,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
return LatentsOutput(
latents=LatentsField(latents_name=name)
)
return build_latents_output(latents_name=name, latents=result_latents)
# Latent to image
@ -384,8 +387,8 @@ class ResizeLatentsInvocation(BaseInvocation):
latents: Optional[LatentsField] = Field(description="The latents to resize")
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
@ -402,7 +405,7 @@ class ResizeLatentsInvocation(BaseInvocation):
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, resized_latents)
return LatentsOutput(latents=LatentsField(latents_name=name))
return build_latents_output(latents_name=name, latents=resized_latents)
class ScaleLatentsInvocation(BaseInvocation):
@ -413,8 +416,8 @@ class ScaleLatentsInvocation(BaseInvocation):
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to scale")
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
@ -432,4 +435,48 @@ class ScaleLatentsInvocation(BaseInvocation):
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, resized_latents)
return LatentsOutput(latents=LatentsField(latents_name=name))
return build_latents_output(latents_name=name, latents=resized_latents)
class ImageToLatentsInvocation(BaseInvocation):
"""Encodes an image into latents."""
type: Literal["i2l"] = "i2l"
# Inputs
image: Union[ImageField, None] = Field(description="The image to encode")
model: str = Field(default="", description="The model to use")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents", "image"],
"type_hints": {"model": "model"},
},
}
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
# TODO: this only really needs the vae
model_info = choose_model(context.services.model_manager, self.model)
model: StableDiffusionGeneratorPipeline = model_info["model"]
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
latents = model.non_noised_latents_from_image(
image_tensor,
device=model._model_group.device_for(model.unet),
dtype=model.unet.dtype,
)
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, latents)
return build_latents_output(latents_name=name, latents=latents)
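
Several nodes in this file now funnel their results through `build_latents_output` and `build_noise_output`, which derive the reported pixel dimensions from the latent tensor using the 8x spatial factor between latent space and image space. A small illustration of that arithmetic, assuming the module path `invokeai.app.invocations.latent` and an environment where its imports (torch, diffusers) resolve:

```
import torch
from invokeai.app.invocations.latent import build_latents_output, build_noise_output

latents = torch.zeros(1, 4, 64, 96)            # NCHW latent for a 512x768 image
out = build_latents_output(latents_name="demo__latents", latents=latents)
print(out.type, out.width, out.height)         # latents_output 768 512

noise = torch.zeros(1, 4, 64, 64)
nout = build_noise_output(latents_name="demo__noise", latents=noise)
print(nout.type, nout.width, nout.height)      # noise_output 512 512
```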

View File

@ -3,8 +3,14 @@
from typing import Literal
from pydantic import BaseModel, Field
import numpy as np
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
InvocationConfig,
)
class MathInvocationConfig(BaseModel):
@ -21,19 +27,21 @@ class MathInvocationConfig(BaseModel):
class IntOutput(BaseInvocationOutput):
"""An integer output"""
#fmt: off
# fmt: off
type: Literal["int_output"] = "int_output"
a: int = Field(default=None, description="The output integer")
#fmt: on
# fmt: on
class AddInvocation(BaseInvocation, MathInvocationConfig):
"""Adds two numbers"""
#fmt: off
# fmt: off
type: Literal["add"] = "add"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a + self.b)
@ -41,11 +49,12 @@ class AddInvocation(BaseInvocation, MathInvocationConfig):
class SubtractInvocation(BaseInvocation, MathInvocationConfig):
"""Subtracts two numbers"""
#fmt: off
# fmt: off
type: Literal["sub"] = "sub"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a - self.b)
@ -53,11 +62,12 @@ class SubtractInvocation(BaseInvocation, MathInvocationConfig):
class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
"""Multiplies two numbers"""
#fmt: off
# fmt: off
type: Literal["mul"] = "mul"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a * self.b)
@ -65,11 +75,26 @@ class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
class DivideInvocation(BaseInvocation, MathInvocationConfig):
"""Divides two numbers"""
#fmt: off
# fmt: off
type: Literal["div"] = "div"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=int(self.a / self.b))
class RandomIntInvocation(BaseInvocation):
"""Outputs a single random integer."""
# fmt: off
type: Literal["rand_int"] = "rand_int"
low: int = Field(default=0, description="The inclusive low value")
high: int = Field(
default=np.iinfo(np.int32).max, description="The exclusive high value"
)
# fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=np.random.randint(self.low, self.high))
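
The math nodes are pure functions of their fields, so they can be exercised without a graph executor. A quick sketch, assuming the module lives at `invokeai.app.invocations.math` (not shown in the diff) and that `BaseInvocation` requires its usual `id` field; `invoke()` ignores its context here, so passing `None` is enough for illustration:

```
from invokeai.app.invocations.math import (
    AddInvocation, DivideInvocation, MultiplyInvocation, RandomIntInvocation
)

print(AddInvocation(id="1", a=2, b=3).invoke(None).a)       # 5
print(MultiplyInvocation(id="2", a=4, b=5).invoke(None).a)  # 20
print(DivideInvocation(id="3", a=7, b=2).invoke(None).a)    # 3, int() truncates toward zero
print(RandomIntInvocation(id="4", low=1, high=7).invoke(None).a)  # 1..6, high is exclusive
```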

View File

@ -4,10 +4,11 @@ from invokeai.backend.model_management.model_manager import ModelManager
def choose_model(model_manager: ModelManager, model_name: str):
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
logger = model_manager.logger
if model_manager.valid_model(model_name):
model = model_manager.get_model(model_name)
else:
if model_name and not model_manager.valid_model(model_name):
default_model_name = model_manager.default_model()
logger.warning(f"\'{model_name}\' is not a valid model name. Using default model \'{default_model_name}\' instead.")
model = model_manager.get_model()
logger.warning(f"{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead.")
else:
model = model_manager.get_model(model_name)
return model
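
With this change the helper only warns and falls back to the default model when a non-empty but unknown name is supplied; an empty name goes straight to `get_model()`. Below is a sketch of the branch logic with a hypothetical stand-in for `ModelManager` (the real manager is constructed in `get_model_manager()` further down):

```
import logging
from invokeai.app.invocations.util.choose_model import choose_model

class FakeModelManager:
    """Hypothetical stub exposing only what choose_model touches."""
    logger = logging.getLogger("fake")
    def valid_model(self, name): return name == "stable-diffusion-1.5"
    def default_model(self): return "stable-diffusion-1.5"
    def get_model(self, name=None):
        return {"model_name": name or self.default_model()}

mm = FakeModelManager()
print(choose_model(mm, "stable-diffusion-1.5"))  # valid name: returned unchanged
print(choose_model(mm, ""))                      # empty name: falls through to get_model("")
print(choose_model(mm, "no-such-model"))         # warns, then returns the default model dict
```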

View File

@ -1,5 +1,5 @@
from enum import Enum
from typing import Optional
from typing import Optional, Tuple
from pydantic import BaseModel, Field
@ -27,3 +27,13 @@ class ImageField(BaseModel):
class Config:
schema_extra = {"required": ["image_type", "image_name"]}
class ColorField(BaseModel):
r: int = Field(ge=0, le=255, description="The red component")
g: int = Field(ge=0, le=255, description="The green component")
b: int = Field(ge=0, le=255, description="The blue component")
a: int = Field(ge=0, le=255, description="The alpha component")
def tuple(self) -> Tuple[int, int, int, int]:
return (self.r, self.g, self.b, self.a)
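
`ColorField` validates each channel to the 0-255 range via pydantic and exposes `tuple()` in the `(r, g, b, a)` order that PIL expects, which is how the infill and inpaint nodes consume it. A minimal check, assuming the module resolves to `invokeai.app.models.image` as the relative import in the infill diff suggests:

```
from PIL import Image
from pydantic import ValidationError
from invokeai.app.models.image import ColorField

grey = ColorField(r=127, g=127, b=127, a=255)
print(grey.tuple())                        # (127, 127, 127, 255)
Image.new("RGBA", (64, 64), grey.tuple())  # usable directly as a PIL fill colour

try:
    ColorField(r=300, g=0, b=0, a=255)     # out of range: le=255
except ValidationError as err:
    print("rejected:", err.errors()[0]["loc"])
```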

View File

@ -0,0 +1,521 @@
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
'''InvokeAI configuration system.
Arguments and fields are taken from the pydantic definition of the
model. Defaults can be set by creating a yaml configuration file that
has a top-level key of "InvokeAI" and subheadings for each of the
categories returned by `invokeai --help`. The file looks like this:
[file: invokeai.yaml]
InvokeAI:
Paths:
root: /home/lstein/invokeai-main
conf_path: configs/models.yaml
legacy_conf_dir: configs/stable-diffusion
outdir: outputs
embedding_dir: embeddings
lora_dir: loras
autoconvert_dir: null
gfpgan_model_dir: models/gfpgan/GFPGANv1.4.pth
Models:
model: stable-diffusion-1.5
embeddings: true
Memory/Performance:
xformers_enabled: false
sequential_guidance: false
precision: float16
max_loaded_models: 4
always_use_cpu: false
free_gpu_mem: false
Features:
nsfw_checker: true
restore: true
esrgan: true
patchmatch: true
internet_available: true
log_tokenization: false
Web Server:
host: 127.0.0.1
port: 8081
allow_origins: []
allow_credentials: true
allow_methods:
- '*'
allow_headers:
- '*'
The default name of the configuration file is `invokeai.yaml`, located
in INVOKEAI_ROOT. You can supersede this by providing any
OmegaConf dictionary object at initialization time:
omegaconf = OmegaConf.load('/tmp/init.yaml')
conf = InvokeAIAppConfig(conf=omegaconf)
By default, InvokeAIAppConfig will parse the contents of `sys.argv` at
initialization time. You may pass a list of strings in the optional
`argv` argument to use instead of the system argv:
conf = InvokeAIAppConfig(argv=['--xformers_enabled'])
It is also possible to set a value at initialization time. This value
has highest priority.
conf = InvokeAIAppConfig(xformers_enabled=True)
Any setting can be overwritten by setting an environment variable of the
form "INVOKEAI_<setting>", as in:
export INVOKEAI_port=8080
Order of precedence (from highest):
1) initialization options
2) command line options
3) environment variable options
4) config file options
5) pydantic defaults
Typical usage:
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.invocations.generate import TextToImageInvocation
# get global configuration and print its nsfw_checker value
conf = InvokeAIAppConfig()
print(conf.nsfw_checker)
# get the text2image invocation and print its step value
text2image = TextToImageInvocation()
print(text2image.steps)
Computed properties:
The InvokeAIAppConfig object has a series of properties that
resolve paths relative to the runtime root directory. They each return
a Path object:
root_path - path to InvokeAI root
output_path - path to default outputs directory
model_conf_path - path to models.yaml
conf - alias for the above
embedding_path - path to the embeddings directory
lora_path - path to the LoRA directory
In most cases, you will want to create a single InvokeAIAppConfig
object for the entire application. The get_invokeai_config() function
does this:
config = get_invokeai_config()
print(config.root)
# Subclassing
If you wish to create a similar class, please subclass the
`InvokeAISettings` class and define a Literal field named "type",
which is set to the desired top-level name. For example, to create an
"InvokeBatch" configuration, define it like this:
class InvokeBatch(InvokeAISettings):
type: Literal["InvokeBatch"] = "InvokeBatch"
node_count : int = Field(default=1, description="Number of nodes to run on", category='Resources')
cpu_count : int = Field(default=8, description="Number of CPUs to run on per node", category='Resources')
This will now read and write from the "InvokeBatch" section of the
config file, look for environment variables named INVOKEBATCH_*, and
accept the command-line arguments `--node_count` and `--cpu_count`. The
two configs are kept in separate sections of the config file:
# invokeai.yaml
InvokeBatch:
Resources:
node_count: 1
cpu_count: 8
InvokeAI:
Paths:
root: /home/lstein/invokeai-main
conf_path: configs/models.yaml
legacy_conf_dir: configs/stable-diffusion
outdir: outputs
...
'''
import argparse
import pydoc
import typing
import os
import sys
from argparse import ArgumentParser
from omegaconf import OmegaConf, DictConfig
from pathlib import Path
from pydantic import BaseSettings, Field, parse_obj_as
from typing import Any, ClassVar, Dict, List, Literal, Type, Union, get_origin, get_type_hints, get_args
INIT_FILE = Path('invokeai.yaml')
LEGACY_INIT_FILE = Path('invokeai.init')
# This global stores a singleton InvokeAIAppConfig configuration object
global_config = None
class InvokeAISettings(BaseSettings):
'''
Runtime configuration settings in which default values are
read from an omegaconf .yaml file.
'''
initconf : ClassVar[DictConfig] = None
argparse_groups : ClassVar[Dict] = {}
def parse_args(self, argv: list=sys.argv[1:]):
parser = self.get_parser()
opt, _ = parser.parse_known_args(argv)
for name in self.__fields__:
if name not in self._excluded():
setattr(self, name, getattr(opt,name))
def to_yaml(self)->str:
"""
Return a YAML string representing our settings. This can be used
as the contents of `invokeai.yaml` to restore settings later.
"""
cls = self.__class__
type = get_args(get_type_hints(cls)['type'])[0]
field_dict = dict({type:dict()})
for name,field in self.__fields__.items():
if name in cls._excluded():
continue
category = field.field_info.extra.get("category") or "Uncategorized"
value = getattr(self,name)
if category not in field_dict[type]:
field_dict[type][category] = dict()
# keep paths as strings to make it easier to read
field_dict[type][category][name] = str(value) if isinstance(value,Path) else value
conf = OmegaConf.create(field_dict)
return OmegaConf.to_yaml(conf)
@classmethod
def add_parser_arguments(cls, parser):
if 'type' in get_type_hints(cls):
settings_stanza = get_args(get_type_hints(cls)['type'])[0]
else:
settings_stanza = "Uncategorized"
env_prefix = cls.Config.env_prefix if hasattr(cls.Config,'env_prefix') else settings_stanza.upper()
initconf = cls.initconf.get(settings_stanza) \
if cls.initconf and settings_stanza in cls.initconf \
else OmegaConf.create()
# create an upcase version of the environment in
# order to achieve case-insensitive environment
# variables (the way Windows does)
upcase_environ = dict()
for key,value in os.environ.items():
upcase_environ[key.upper()] = value
fields = cls.__fields__
cls.argparse_groups = {}
for name, field in fields.items():
if name not in cls._excluded():
current_default = field.default
category = field.field_info.extra.get("category","Uncategorized")
env_name = env_prefix + '_' + name
if category in initconf and name in initconf.get(category):
field.default = initconf.get(category).get(name)
if env_name.upper() in upcase_environ:
field.default = upcase_environ[env_name.upper()]
cls.add_field_argument(parser, name, field)
field.default = current_default
@classmethod
def cmd_name(self, command_field: str='type')->str:
hints = get_type_hints(self)
if command_field in hints:
return get_args(hints[command_field])[0]
else:
return 'Uncategorized'
@classmethod
def get_parser(cls)->ArgumentParser:
parser = PagingArgumentParser(
prog=cls.cmd_name(),
description=cls.__doc__,
)
cls.add_parser_arguments(parser)
return parser
@classmethod
def add_subparser(cls, parser: argparse.ArgumentParser):
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
@classmethod
def _excluded(self)->List[str]:
return ['type','initconf']
class Config:
env_file_encoding = 'utf-8'
arbitrary_types_allowed = True
case_sensitive = True
@classmethod
def add_field_argument(cls, command_parser, name: str, field, default_override = None):
field_type = get_type_hints(cls).get(name)
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
if category := field.field_info.extra.get("category"):
if category not in cls.argparse_groups:
cls.argparse_groups[category] = command_parser.add_argument_group(category)
argparse_group = cls.argparse_groups[category]
else:
argparse_group = command_parser
if get_origin(field_type) == Literal:
allowed_values = get_args(field.type_)
allowed_types = set()
for val in allowed_values:
allowed_types.add(type(val))
allowed_types_list = list(allowed_types)
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
argparse_group.add_argument(
f"--{name}",
dest=name,
type=field_type,
default=default,
choices=allowed_values,
help=field.field_info.description,
)
elif get_origin(field_type) == list:
argparse_group.add_argument(
f"--{name}",
dest=name,
nargs='*',
type=field.type_,
default=default,
action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
help=field.field_info.description,
)
else:
argparse_group.add_argument(
f"--{name}",
dest=name,
type=field.type_,
default=default,
action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
help=field.field_info.description,
)
def _find_root()->Path:
if os.environ.get("INVOKEAI_ROOT"):
root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
elif (
os.environ.get("VIRTUAL_ENV")
and (Path(os.environ.get("VIRTUAL_ENV"), "..", INIT_FILE).exists()
or
Path(os.environ.get("VIRTUAL_ENV"), "..", LEGACY_INIT_FILE).exists()
)
):
root = Path(os.environ.get("VIRTUAL_ENV"), "..").resolve()
else:
root = Path("~/invokeai").expanduser().resolve()
return root
class InvokeAIAppConfig(InvokeAISettings):
'''
Generate images using Stable Diffusion. Use "invokeai" to launch
the command-line client (recommended for experts only), or
"invokeai-web" to launch the web server. Global options
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
setting environment variables INVOKEAI_<setting>.
'''
#fmt: off
type: Literal["InvokeAI"] = "InvokeAI"
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", category='Web Server')
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", category='Web Server')
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", category='Web Server')
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", category='Web Server')
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features')
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
restore : bool = Field(default=True, description="Enable/disable face restoration code", category='Features')
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
max_loaded_models : int = Field(default=2, gt=0, description="Maximum number of models to keep in memory for rapid switching", category='Memory/Performance')
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance')
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths')
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
embedding_dir : Path = Field(default='embeddings', description='Path to InvokeAI textual inversion embeddings directory', category='Paths')
gfpgan_model_dir : Path = Field(default="./models/gfpgan/GFPGANv1.4.pth", description='Path to GFPGAN models directory.', category='Paths')
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
lora_dir : Path = Field(default='loras', description='Path to InvokeAI LoRA model directory', category='Paths')
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models')
#fmt: on
def __init__(self, conf: DictConfig = None, argv: List[str]=None, **kwargs):
'''
Initialize InvokeAIAppConfig.
:param conf: alternate OmegaConf dictionary object
:param argv: alternate sys.argv list
:param **kwargs: attributes to initialize with
'''
super().__init__(**kwargs)
# Set the runtime root directory. We parse command-line switches here
# in order to pick up the --root_dir option.
self.parse_args(argv)
if conf is None:
try:
conf = OmegaConf.load(self.root_dir / INIT_FILE)
except:
pass
InvokeAISettings.initconf = conf
# parse args again in order to pick up settings in configuration file
self.parse_args(argv)
# restore initialization values
hints = get_type_hints(self)
for k in kwargs:
setattr(self,k,parse_obj_as(hints[k],kwargs[k]))
@property
def root_path(self)->Path:
'''
Path to the runtime root directory
'''
if self.root:
return Path(self.root).expanduser()
else:
return self.find_root()
@property
def root_dir(self)->Path:
'''
Alias for above.
'''
return self.root_path
def _resolve(self,partial_path:Path)->Path:
return (self.root_path / partial_path).resolve()
@property
def output_path(self)->Path:
'''
Path to the default outputs directory.
'''
return self._resolve(self.outdir)
@property
def model_conf_path(self)->Path:
'''
Path to models configuration file.
'''
return self._resolve(self.conf_path)
@property
def legacy_conf_path(self)->Path:
'''
Path to directory of legacy configuration files (e.g. v1-inference.yaml)
'''
return self._resolve(self.legacy_conf_dir)
@property
def cache_dir(self)->Path:
'''
Path to the global cache directory for HuggingFace hub-managed models
'''
return self.models_dir / "hub"
@property
def models_dir(self)->Path:
'''
Path to the models directory
'''
return self._resolve("models")
@property
def embedding_path(self)->Path:
'''
Path to the textual inversion embeddings directory.
'''
return self._resolve(self.embedding_dir) if self.embedding_dir else None
@property
def lora_path(self)->Path:
'''
Path to the LoRA models directory.
'''
return self._resolve(self.lora_dir) if self.lora_dir else None
@property
def autoconvert_path(self)->Path:
'''
Path to the directory containing models to be imported automatically at startup.
'''
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
@property
def gfpgan_model_path(self)->Path:
'''
Path to the GFPGAN model.
'''
return self._resolve(self.gfpgan_model_dir) if self.gfpgan_model_dir else None
# the following methods support legacy calls leftover from the Globals era
@property
def full_precision(self)->bool:
"""Return true if precision set to float32"""
return self.precision=='float32'
@property
def disable_xformers(self)->bool:
"""Return true if xformers_enabled is false"""
return not self.xformers_enabled
@property
def try_patchmatch(self)->bool:
"""Return true if patchmatch true"""
return self.patchmatch
@staticmethod
def find_root()->Path:
'''
Choose the runtime root directory when it is not specified on the command
line or in the init file.
'''
return _find_root()
class PagingArgumentParser(argparse.ArgumentParser):
'''
A custom ArgumentParser that uses pydoc to page its output.
It also supports reading defaults from an init file.
'''
def print_help(self, file=None):
text = self.format_help()
pydoc.pager(text)
def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig,**kwargs)->InvokeAISettings:
'''
This returns a singleton InvokeAIAppConfig configuration object.
'''
global global_config
if global_config is None or type(global_config)!=cls:
global_config = cls(**kwargs)
return global_config
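
A short sketch tying together the precedence rules, the computed path properties, and the singleton accessor defined above; the printed values are illustrative and depend on your `INVOKEAI_ROOT` and any existing `invokeai.yaml`:

```
import os
from invokeai.app.services.config import InvokeAIAppConfig, get_invokeai_config

os.environ["INVOKEAI_max_loaded_models"] = "3"           # environment layer
conf = InvokeAIAppConfig(argv=[], always_use_cpu=True)   # initialization layer

print(conf.max_loaded_models)  # 3: the environment variable overrides the pydantic default of 2
print(conf.always_use_cpu)     # True: initialization values have the highest priority

# Computed properties resolve relative to the runtime root directory:
print(conf.root_path)          # e.g. /home/<user>/invokeai
print(conf.output_path)        # <root>/outputs
print(conf.model_conf_path)    # <root>/configs/models.yaml
print(conf.cache_dir)          # <root>/models/hub

# Application code normally goes through the singleton accessor instead:
assert get_invokeai_config(argv=[]) is get_invokeai_config()
```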

View File

@ -49,12 +49,13 @@ def create_text_to_image() -> LibraryGraph:
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
graphs: list[LibraryGraph] = list()
text_to_image = graph_library.get(default_text_to_image_graph_id)
# text_to_image = graph_library.get(default_text_to_image_graph_id)
# TODO: Check if the graph is the same as the default one, and if not, update it
#if text_to_image is None:
# # TODO: Check if the graph is the same as the default one, and if not, update it
# #if text_to_image is None:
text_to_image = create_text_to_image()
graph_library.set(text_to_image)

View File

@ -135,6 +135,7 @@ class GraphInvocationOutput(BaseInvocationOutput):
# TODO: Fill this out and move to invocations
class GraphInvocation(BaseInvocation):
"""Execute a graph"""
type: Literal["graph"] = "graph"
# TODO: figure out how to create a default here
@ -162,6 +163,7 @@ class IterateInvocationOutput(BaseInvocationOutput):
# TODO: Fill this out and move to invocations
class IterateInvocation(BaseInvocation):
"""Iterates over a list of items"""
type: Literal["iterate"] = "iterate"
collection: list[Any] = Field(

View File

@ -270,4 +270,5 @@ class DiskImageStorage(ImageStorageBase):
) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size:
cache_id = self.__cache_ids.get()
if cache_id in self.__cache:
del self.__cache[cache_id]

View File

@ -10,6 +10,7 @@ from .image_storage import ImageStorageBase
from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC
from .config import InvokeAISettings
class InvocationServices:
"""Services that can be used by invocations"""
@ -21,6 +22,7 @@ class InvocationServices:
queue: InvocationQueueABC
model_manager: ModelManager
restoration: RestorationServices
configuration: InvokeAISettings
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
graph_library: ItemStorageABC["LibraryGraph"]
@ -40,6 +42,7 @@ class InvocationServices:
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC",
restoration: RestorationServices,
configuration: InvokeAISettings=None,
):
self.model_manager = model_manager
self.events = events
@ -52,3 +55,4 @@ class InvocationServices:
self.graph_execution_manager = graph_execution_manager
self.processor = processor
self.restoration = restoration
self.configuration = configuration

View File

@ -20,9 +20,18 @@ class MetadataLatentsField(TypedDict):
latents_name: str
class MetadataColorField(TypedDict):
"""Pydantic-less ColorField, used for metadata parsing"""
r: int
g: int
b: int
a: int
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
NodeMetadata = Dict[
str, str | int | float | bool | MetadataImageField | MetadataLatentsField
str, None | str | int | float | bool | MetadataImageField | MetadataLatentsField | MetadataColorField
]

View File

@ -2,27 +2,25 @@ import os
import sys
import torch
from argparse import Namespace
from invokeai.backend import Args
from omegaconf import OmegaConf
from pathlib import Path
from typing import types
import invokeai.version
from .config import InvokeAISettings
from ...backend import ModelManager
from ...backend.util import choose_precision, choose_torch_device
from ...backend import Globals
# TODO: Replace with an abstract class base ModelManagerBase
def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
if not config.conf:
config_file = os.path.join(Globals.root, "configs", "models.yaml")
if not os.path.exists(config_file):
def get_model_manager(config: InvokeAISettings, logger: types.ModuleType) -> ModelManager:
model_config = config.model_conf_path
if not model_config.exists():
report_model_error(
config, FileNotFoundError(f"The file {config_file} could not be found."), logger
config, FileNotFoundError(f"The file {model_config} could not be found."), logger
)
logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
logger.info(f'InvokeAI runtime directory is "{Globals.root}"')
logger.info(f'InvokeAI runtime directory is "{config.root}"')
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
@ -32,20 +30,7 @@ def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
import diffusers
diffusers.logging.set_verbosity_error()
# normalize the config directory relative to root
if not os.path.isabs(config.conf):
config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))
if config.embeddings:
if not os.path.isabs(config.embedding_path):
embedding_path = os.path.normpath(
os.path.join(Globals.root, config.embedding_path)
)
else:
embedding_path = config.embedding_path
else:
embedding_path = None
# migrate legacy models
ModelManager.migrate_models()
@ -58,11 +43,11 @@ def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
else choose_precision(device)
model_manager = ModelManager(
OmegaConf.load(config.conf),
OmegaConf.load(config.model_conf_path),
precision=precision,
device_type=device,
max_loaded_models=config.max_loaded_models,
embedding_path = Path(embedding_path),
embedding_path = embedding_path,
logger = logger,
)
except (FileNotFoundError, TypeError, AssertionError) as e:
@ -73,12 +58,10 @@ def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
# try to autoconvert new models
# autoimport new .ckpt files
if path := config.autoconvert:
model_manager.autoconvert_weights(
conf_path=config.conf,
weights_directory=path,
if config.autoconvert_path:
model_manager.heuristic_import(
config.autoconvert_path,
)
logger.info('Model manager initialized')
return model_manager
def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):

View File

@ -1,3 +1,4 @@
import time
import traceback
from threading import Event, Thread, BoundedSemaphore
@ -6,6 +7,7 @@ from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker
from ..models.exceptions import CanceledException
import invokeai.backend.util.logging as logger
class DefaultInvocationProcessor(InvocationProcessorABC):
__invoker_thread: Thread
__stop_event: Event
@ -34,8 +36,14 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
try:
self.__threadLimit.acquire()
while not stop_event.is_set():
try:
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
except Exception as e:
logger.debug("Exception while getting from queue: %s" % e)
if not queue_item: # Probably stopping
# do not hammer the queue
time.sleep(0.5)
continue
graph_execution_state = (
@ -124,7 +132,16 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Queue any further commands if invoking all
is_complete = graph_execution_state.is_complete()
if queue_item.invoke_all and not is_complete:
try:
self.__invoker.invoke(graph_execution_state, invoke_all=True)
except Exception as e:
logger.error("Error while invoking: %s" % e)
self.__invoker.services.events.emit_invocation_error(
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
error=traceback.format_exc()
)
elif is_complete:
self.__invoker.services.events.emit_graph_execution_complete(
graph_execution_state.id

View File

@ -1,5 +1,13 @@
import datetime
import numpy as np
def get_timestamp():
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
SEED_MAX = np.iinfo(np.int32).max
def get_random_seed():
return np.random.randint(0, SEED_MAX)
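
`SEED_MAX` and `get_random_seed` centralize the seed handling that `NoiseInvocation` and `InfillTileInvocation` now share: used as a pydantic `default_factory`, every new node instance gets a fresh random seed unless one is supplied. A small sketch with a hypothetical `SeedDemo` model:

```
from pydantic import BaseModel, Field
from invokeai.app.util.misc import SEED_MAX, get_random_seed

class SeedDemo(BaseModel):
    seed: int = Field(ge=0, le=SEED_MAX, default_factory=get_random_seed,
                      description="The seed to use (omit for random)")

print(SEED_MAX)                # np.iinfo(np.int32).max == 2147483647
print(SeedDemo().seed)         # a fresh random seed per instance
print(SeedDemo(seed=42).seed)  # 42 when supplied explicitly
```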

View File

@ -1,7 +1,6 @@
"""
Initialization file for invokeai.backend
"""
from .generate import Generate
from .generator import (
InvokeAIGeneratorBasicParams,
InvokeAIGenerator,
@ -12,5 +11,3 @@ from .generator import (
)
from .model_management import ModelManager, SDModelComponent
from .safety_checker import SafetyChecker
from .args import Args
from .globals import Globals

File diff suppressed because it is too large

View File

@ -19,10 +19,10 @@ import warnings
from argparse import Namespace
from pathlib import Path
from shutil import get_terminal_size
from typing import get_type_hints
from urllib import request
import npyscreen
import torch
import transformers
from diffusers import AutoencoderKL
from huggingface_hub import HfFolder
@ -38,34 +38,40 @@ from transformers import (
import invokeai.configs as configs
from ...frontend.install.model_install import addModelsForm, process_and_execute
from ...frontend.install.widgets import (
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
from invokeai.frontend.install.widgets import (
CenteredButtonPress,
IntTitleSlider,
set_min_terminal_size,
)
from ..args import PRECISION_CHOICES, Args
from ..globals import Globals, global_cache_dir, global_config_dir, global_config_file
from .model_install_backend import (
from invokeai.backend.config.legacy_arg_parsing import legacy_parser
from invokeai.backend.config.model_install_backend import (
default_dataset,
download_from_hf,
hf_download_with_resume,
recommended_datasets,
)
from invokeai.app.services.config import (
get_invokeai_config,
InvokeAIAppConfig,
)
warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()
# --------------------------globals-----------------------
config = get_invokeai_config()
Model_dir = "models"
Weights_dir = "ldm/stable-diffusion-v1/"
# the initial "configs" dir is now bundled in the `invokeai.configs` package
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
Default_config_file = Path(global_config_dir()) / "models.yaml"
SD_Configs = Path(global_config_dir()) / "stable-diffusion"
Default_config_file = config.model_conf_path
SD_Configs = config.legacy_conf_path
Datasets = OmegaConf.load(Dataset_path)
@ -73,17 +79,12 @@ Datasets = OmegaConf.load(Dataset_path)
MIN_COLS = 135
MIN_LINES = 45
PRECISION_CHOICES = ['auto','float16','float32','autocast']
INIT_FILE_PREAMBLE = """# InvokeAI initialization file
# This is the InvokeAI initialization file, which contains command-line default values.
# Feel free to edit. If anything goes wrong, you can re-initialize this file by deleting
# or renaming it and then running invokeai-configure again.
# Place frequently-used startup commands here, one or more per line.
# Examples:
# --outdir=D:\data\images
# --no-nsfw_checker
# --web --host=0.0.0.0
# --steps=20
# -Ak_euler_a -C10.0
"""
@ -96,14 +97,13 @@ If you installed manually from source or with 'pip install': activate the virtua
then run one of the following commands to start InvokeAI.
Web UI:
invokeai --web # (connect to http://localhost:9090)
invokeai --web --host 0.0.0.0 # (connect to http://your-lan-ip:9090 from another computer on the local network)
invokeai-web
Command-line interface:
Command-line client:
invokeai
If you installed using an installation script, run:
{Globals.root}/invoke.{"bat" if sys.platform == "win32" else "sh"}
{config.root}/invoke.{"bat" if sys.platform == "win32" else "sh"}
Add the '--help' argument to see all of the command-line switches available for use.
"""
@ -216,11 +216,11 @@ def download_realesrgan():
wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth"
model_dest = os.path.join(
Globals.root, "models/realesrgan/realesr-general-x4v3.pth"
config.root, "models/realesrgan/realesr-general-x4v3.pth"
)
wdn_model_dest = os.path.join(
Globals.root, "models/realesrgan/realesr-general-wdn-x4v3.pth"
config.root, "models/realesrgan/realesr-general-wdn-x4v3.pth"
)
download_with_progress_bar(model_url, model_dest, "RealESRGAN")
@ -243,7 +243,7 @@ def download_gfpgan():
"./models/gfpgan/weights/parsing_parsenet.pth",
],
):
model_url, model_dest = model[0], os.path.join(Globals.root, model[1])
model_url, model_dest = model[0], os.path.join(config.root, model[1])
download_with_progress_bar(model_url, model_dest, "GFPGAN weights")
@ -253,7 +253,7 @@ def download_codeformer():
model_url = (
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
)
model_dest = os.path.join(Globals.root, "models/codeformer/codeformer.pth")
model_dest = os.path.join(config.root, "models/codeformer/codeformer.pth")
download_with_progress_bar(model_url, model_dest, "CodeFormer")
@ -295,7 +295,7 @@ def download_vaes():
# first the diffusers version
repo_id = "stabilityai/sd-vae-ft-mse"
args = dict(
cache_dir=global_cache_dir("hub"),
cache_dir=config.cache_dir,
)
if not AutoencoderKL.from_pretrained(repo_id, **args):
raise Exception(f"download of {repo_id} failed")
@ -306,7 +306,7 @@ def download_vaes():
if not hf_download_with_resume(
repo_id=repo_id,
model_name=model_name,
model_dir=str(Globals.root / Model_dir / Weights_dir),
model_dir=str(config.root / Model_dir / Weights_dir),
):
raise Exception(f"download of {model_name} failed")
except Exception as e:
@ -321,8 +321,7 @@ def get_root(root: str = None) -> str:
elif os.environ.get("INVOKEAI_ROOT"):
return os.environ.get("INVOKEAI_ROOT")
else:
return Globals.root
return config.root
# -------------------------------------
class editOptsForm(npyscreen.FormMultiPage):
@ -332,7 +331,7 @@ class editOptsForm(npyscreen.FormMultiPage):
def create(self):
program_opts = self.parentApp.program_opts
old_opts = self.parentApp.invokeai_opts
first_time = not (Globals.root / Globals.initfile).exists()
first_time = not (config.root / 'invokeai.yaml').exists()
access_token = HfFolder.get_token()
window_width, window_height = get_terminal_size()
for i in [
@ -366,7 +365,7 @@ class editOptsForm(npyscreen.FormMultiPage):
self.outdir = self.add_widget_intelligent(
npyscreen.TitleFilename,
name="(<tab> autocompletes, ctrl-N advances):",
value=old_opts.outdir or str(default_output_dir()),
value=str(old_opts.outdir) or str(default_output_dir()),
select_dir=True,
must_exist=False,
use_two_lines=False,
@ -381,17 +380,17 @@ class editOptsForm(npyscreen.FormMultiPage):
editable=False,
color="CONTROL",
)
self.safety_checker = self.add_widget_intelligent(
self.nsfw_checker = self.add_widget_intelligent(
npyscreen.Checkbox,
name="NSFW checker",
value=old_opts.safety_checker,
value=old_opts.nsfw_checker,
relx=5,
scroll_exit=True,
)
self.nextrely += 1
for i in [
"If you have an account at HuggingFace you may paste your access token here",
'to allow InvokeAI to download styles & subjects from the "Concept Library".',
"If you have an account at HuggingFace you may optionally paste your access token here",
'to allow InvokeAI to download restricted styles & subjects from the "Concept Library".',
"See https://huggingface.co/settings/tokens",
]:
self.add_widget_intelligent(
@ -435,17 +434,10 @@ class editOptsForm(npyscreen.FormMultiPage):
relx=5,
scroll_exit=True,
)
self.xformers = self.add_widget_intelligent(
self.xformers_enabled = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Enable xformers support if available",
value=old_opts.xformers,
relx=5,
scroll_exit=True,
)
self.ckpt_convert = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Load legacy checkpoint models into memory as diffusers models",
value=old_opts.ckpt_convert,
value=old_opts.xformers_enabled,
relx=5,
scroll_exit=True,
)
@ -480,19 +472,30 @@ class editOptsForm(npyscreen.FormMultiPage):
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.FixedText,
value="Directory containing embedding/textual inversion files:",
value="Directories containing textual inversion and LoRA models (<tab> autocompletes, ctrl-N advances):",
editable=False,
color="CONTROL",
)
self.embedding_path = self.add_widget_intelligent(
self.embedding_dir = self.add_widget_intelligent(
npyscreen.TitleFilename,
name="(<tab> autocompletes, ctrl-N advances):",
name=" Textual Inversion Embeddings:",
value=str(default_embedding_dir()),
select_dir=True,
must_exist=False,
use_two_lines=False,
labelColor="GOOD",
begin_entry_at=40,
begin_entry_at=32,
scroll_exit=True,
)
self.lora_dir = self.add_widget_intelligent(
npyscreen.TitleFilename,
name=" LoRA and LyCORIS:",
value=str(default_lora_dir()),
select_dir=True,
must_exist=False,
use_two_lines=False,
labelColor="GOOD",
begin_entry_at=32,
scroll_exit=True,
)
self.nextrely += 1
@ -559,9 +562,9 @@ class editOptsForm(npyscreen.FormMultiPage):
bad_fields.append(
f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory."
)
if not Path(opt.embedding_path).parent.exists():
if not Path(opt.embedding_dir).parent.exists():
bad_fields.append(
f"The embedding directory does not seem to be valid. Please check that {str(Path(opt.embedding_path).parent)} is an existing directory."
f"The embedding directory does not seem to be valid. Please check that {str(Path(opt.embedding_dir).parent)} is an existing directory."
)
if len(bad_fields) > 0:
message = "The following problems were detected and must be corrected:\n"
@ -577,13 +580,13 @@ class editOptsForm(npyscreen.FormMultiPage):
for attr in [
"outdir",
"safety_checker",
"nsfw_checker",
"free_gpu_mem",
"max_loaded_models",
"xformers",
"xformers_enabled",
"always_use_cpu",
"embedding_path",
"ckpt_convert",
"embedding_dir",
"lora_dir",
]:
setattr(new_opts, attr, getattr(self, attr).value)
@ -591,6 +594,9 @@ class editOptsForm(npyscreen.FormMultiPage):
new_opts.license_acceptance = self.license_acceptance.value
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
# widget library workaround to make max_loaded_models an int rather than a float
new_opts.max_loaded_models = int(new_opts.max_loaded_models)
return new_opts
@ -628,15 +634,14 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
def default_startup_options(init_file: Path) -> Namespace:
opts = Args().parse_args([])
opts = InvokeAIAppConfig(argv=[])
outdir = Path(opts.outdir)
if not outdir.is_absolute():
opts.outdir = str(Globals.root / opts.outdir)
opts.outdir = str(config.root / opts.outdir)
if not init_file.exists():
opts.safety_checker = True
opts.nsfw_checker = True
return opts
def default_user_selections(program_opts: Namespace) -> Namespace:
return Namespace(
starter_models=default_dataset()
@ -690,70 +695,61 @@ def run_console_ui(
# -------------------------------------
def write_opts(opts: Namespace, init_file: Path):
"""
Update the invokeai.init file with values from opts Namespace
Update the invokeai.yaml file with values from current settings.
"""
# touch file if it doesn't exist
if not init_file.exists():
with open(init_file, "w") as f:
f.write(INIT_FILE_PREAMBLE)
# We want to write in the changed arguments without clobbering
# any other initialization values the user has entered. There is
# no good way to do this because of the one-way nature of
# argparse: i.e. --outdir could be --outdir, --out, or -o
# initfile needs to be replaced with a fully structured format
# such as yaml; this is a hack that will work much of the time
args_to_skip = re.compile(
"^--?(o|out|no-xformer|xformer|no-ckpt|ckpt|free|no-nsfw|nsfw|prec|max_load|embed|always|ckpt|free_gpu)"
)
# fix windows paths
opts.outdir = opts.outdir.replace("\\", "/")
opts.embedding_path = opts.embedding_path.replace("\\", "/")
new_file = f"{init_file}.new"
try:
lines = [x.strip() for x in open(init_file, "r").readlines()]
with open(new_file, "w") as out_file:
for line in lines:
if len(line) > 0 and not args_to_skip.match(line):
out_file.write(line + "\n")
out_file.write(
f"""
--outdir={opts.outdir}
--embedding_path={opts.embedding_path}
--precision={opts.precision}
--max_loaded_models={int(opts.max_loaded_models)}
--{'no-' if not opts.safety_checker else ''}nsfw_checker
--{'no-' if not opts.xformers else ''}xformers
--{'no-' if not opts.ckpt_convert else ''}ckpt_convert
{'--free_gpu_mem' if opts.free_gpu_mem else ''}
{'--always_use_cpu' if opts.always_use_cpu else ''}
"""
)
except OSError as e:
print(f"** An error occurred while writing the init file: {str(e)}")
os.replace(new_file, init_file)
if opts.hf_token:
HfLogin(opts.hf_token)
# this will load current settings
config = InvokeAIAppConfig()
for key,value in opts.__dict__.items():
if hasattr(config,key):
setattr(config,key,value)
with open(init_file,'w', encoding='utf-8') as file:
file.write(config.to_yaml())
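For reference, a minimal sketch (not part of the diff) of the persistence pattern the rewritten `write_opts` relies on: load the current settings, overwrite the fields of interest, and serialize the whole configuration back to `invokeai.yaml`. The specific field values below are illustrative only.
```
from pathlib import Path
from invokeai.app.services.config import InvokeAIAppConfig

# Load current settings without consuming sys.argv (argv=[]), tweak a couple of
# fields, then write the full (modified) configuration out as YAML.
config = InvokeAIAppConfig(argv=[])
config.nsfw_checker = True          # field names come from the pydantic model
config.max_loaded_models = 2
Path("invokeai.yaml").write_text(config.to_yaml(), encoding="utf-8")
```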
# -------------------------------------
def default_output_dir() -> Path:
return Globals.root / "outputs"
return config.root / "outputs"
# -------------------------------------
def default_embedding_dir() -> Path:
return Globals.root / "embeddings"
return config.root / "embeddings"
# -------------------------------------
def default_lora_dir() -> Path:
return config.root / "loras"
# -------------------------------------
def write_default_options(program_opts: Namespace, initfile: Path):
opt = default_startup_options(initfile)
opt.hf_token = HfFolder.get_token()
write_opts(opt, initfile)
# -------------------------------------
# Here we bring in
# the legacy Args object in order to parse
# the old init file and write out the new
# yaml format.
def migrate_init_file(legacy_format:Path):
old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
new = InvokeAIAppConfig(conf={})
fields = list(get_type_hints(InvokeAIAppConfig).keys())
for attr in fields:
if hasattr(old,attr):
setattr(new,attr,getattr(old,attr))
# a few places where the field names have changed and we have to
# manually add in the new names/values
new.nsfw_checker = old.safety_checker
new.xformers_enabled = old.xformers
new.conf_path = old.conf
new.embedding_dir = old.embedding_path
invokeai_yaml = legacy_format.parent / 'invokeai.yaml'
with open(invokeai_yaml,"w", encoding="utf-8") as outfile:
outfile.write(new.to_yaml())
legacy_format.replace(legacy_format.parent / 'invokeai.init.old')
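For reference, a minimal sketch (not part of the diff) of how the migration helper above is driven; the runtime root shown is an assumption for illustration.
```
from pathlib import Path

root = Path("~/invokeai").expanduser()               # assumed runtime root
old_init = root / "invokeai.init"
if old_init.exists() and not (root / "invokeai.yaml").exists():
    migrate_init_file(old_init)        # writes invokeai.yaml, renames old file to invokeai.init.old
    config = get_invokeai_config()     # reread settings from the new yaml file
```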
# -------------------------------------
def main():
@ -810,7 +806,8 @@ def main():
opt = parser.parse_args()
# setting a global here
Globals.root = Path(os.path.expanduser(get_root(opt.root) or ""))
global config
config.root = Path(os.path.expanduser(get_root(opt.root) or ""))
errors = set()
@ -818,19 +815,26 @@ def main():
models_to_download = default_user_selections(opt)
# We check to see if the runtime directory is correctly initialized.
init_file = Path(Globals.root, Globals.initfile)
if not init_file.exists() or not global_config_file().exists():
initialize_rootdir(Globals.root, opt.yes_to_all)
old_init_file = Path(config.root, 'invokeai.init')
new_init_file = Path(config.root, 'invokeai.yaml')
if old_init_file.exists() and not new_init_file.exists():
print('** Migrating invokeai.init to invokeai.yaml')
migrate_init_file(old_init_file)
config = get_invokeai_config() # reread defaults
if not config.model_conf_path.exists():
initialize_rootdir(config.root, opt.yes_to_all)
if opt.yes_to_all:
write_default_options(opt, init_file)
write_default_options(opt, new_init_file)
init_options = Namespace(
precision="float32" if opt.full_precision else "float16"
)
else:
init_options, models_to_download = run_console_ui(opt, init_file)
init_options, models_to_download = run_console_ui(opt, new_init_file)
if init_options:
write_opts(init_options, init_file)
write_opts(init_options, new_init_file)
else:
print(
'\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'


@ -0,0 +1,390 @@
# Copyright 2023 Lincoln D. Stein and the InvokeAI Team
import argparse
import shlex
from argparse import ArgumentParser
SAMPLER_CHOICES = [
"ddim",
"ddpm",
"deis",
"lms",
"pndm",
"heun",
"heun_k",
"euler",
"euler_k",
"euler_a",
"kdpm_2",
"kdpm_2_a",
"dpmpp_2s",
"dpmpp_2m",
"dpmpp_2m_k",
"unipc",
]
PRECISION_CHOICES = [
"auto",
"float32",
"autocast",
"float16",
]
class FileArgumentParser(ArgumentParser):
"""
Supports reading defaults from an init file.
"""
def convert_arg_line_to_args(self, arg_line):
return shlex.split(arg_line, comments=True)
legacy_parser = FileArgumentParser(
description=
"""
Generate images using Stable Diffusion.
Use --web to launch the web interface.
Use --from_file to load prompts from a file path or standard input ("-").
Otherwise you will be dropped into an interactive command prompt (type -h for help.)
Other command-line arguments are defaults that can usually be overridden
at the command prompt.
""",
fromfile_prefix_chars='@',
)
general_group = legacy_parser.add_argument_group('General')
model_group = legacy_parser.add_argument_group('Model selection')
file_group = legacy_parser.add_argument_group('Input/output')
web_server_group = legacy_parser.add_argument_group('Web server')
render_group = legacy_parser.add_argument_group('Rendering')
postprocessing_group = legacy_parser.add_argument_group('Postprocessing')
deprecated_group = legacy_parser.add_argument_group('Deprecated options')
deprecated_group.add_argument('--laion400m')
deprecated_group.add_argument('--weights') # deprecated
general_group.add_argument(
'--version','-V',
action='store_true',
help='Print InvokeAI version number'
)
model_group.add_argument(
'--root_dir',
default=None,
help='Path to directory containing "models", "outputs" and "configs". If not present will read from environment variable INVOKEAI_ROOT. Defaults to ~/invokeai.',
)
model_group.add_argument(
'--config',
'-c',
'-config',
dest='conf',
default='./configs/models.yaml',
help='Path to configuration file for alternate models.',
)
model_group.add_argument(
'--model',
help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
)
model_group.add_argument(
'--weight_dirs',
nargs='+',
type=str,
help='List of one or more directories that will be auto-scanned for new model weights to import',
)
model_group.add_argument(
'--png_compression','-z',
type=int,
default=6,
choices=range(0,9),
dest='png_compression',
help='level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.'
)
model_group.add_argument(
'-F',
'--full_precision',
dest='full_precision',
action='store_true',
help='Deprecated way to set --precision=float32',
)
model_group.add_argument(
'--max_loaded_models',
dest='max_loaded_models',
type=int,
default=2,
help='Maximum number of models to keep in memory for fast switching, including the one in GPU',
)
model_group.add_argument(
'--free_gpu_mem',
dest='free_gpu_mem',
action='store_true',
help='Force free gpu memory before final decoding',
)
model_group.add_argument(
'--sequential_guidance',
dest='sequential_guidance',
action='store_true',
help="Calculate guidance in serial instead of in parallel, lowering memory requirement "
"at the expense of speed",
)
model_group.add_argument(
'--xformers',
action=argparse.BooleanOptionalAction,
default=True,
help='Enable/disable xformers support (default enabled if installed)',
)
model_group.add_argument(
"--always_use_cpu",
dest="always_use_cpu",
action="store_true",
help="Force use of CPU even if GPU is available"
)
model_group.add_argument(
'--precision',
dest='precision',
type=str,
choices=PRECISION_CHOICES,
metavar='PRECISION',
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
default='auto',
)
model_group.add_argument(
'--ckpt_convert',
action=argparse.BooleanOptionalAction,
dest='ckpt_convert',
default=True,
help='Deprecated option. Legacy ckpt files are now always converted to diffusers when loaded.'
)
model_group.add_argument(
'--internet',
action=argparse.BooleanOptionalAction,
dest='internet_available',
default=True,
help='Indicate whether internet is available for just-in-time model downloading (default: probe automatically).',
)
model_group.add_argument(
'--nsfw_checker',
'--safety_checker',
action=argparse.BooleanOptionalAction,
dest='safety_checker',
default=False,
help='Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.',
)
model_group.add_argument(
'--autoimport',
default=None,
type=str,
help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly',
)
model_group.add_argument(
'--autoconvert',
default=None,
type=str,
help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models',
)
model_group.add_argument(
'--patchmatch',
action=argparse.BooleanOptionalAction,
default=True,
help='Load the patchmatch extension for outpainting. Use --no-patchmatch to disable.',
)
file_group.add_argument(
'--from_file',
dest='infile',
type=str,
help='If specified, load prompts from this file',
)
file_group.add_argument(
'--outdir',
'-o',
type=str,
help='Directory to save generated images and a log of prompts and seeds. Default: ROOTDIR/outputs',
default='outputs',
)
file_group.add_argument(
'--prompt_as_dir',
'-p',
action='store_true',
help='Place images in subdirectories named after the prompt.',
)
render_group.add_argument(
'--fnformat',
default='{prefix}.{seed}.png',
type=str,
help='Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png',
)
render_group.add_argument(
'-s',
'--steps',
type=int,
default=50,
help='Number of steps'
)
render_group.add_argument(
'-W',
'--width',
type=int,
help='Image width, multiple of 64',
)
render_group.add_argument(
'-H',
'--height',
type=int,
help='Image height, multiple of 64',
)
render_group.add_argument(
'-C',
'--cfg_scale',
default=7.5,
type=float,
help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
)
render_group.add_argument(
'--sampler',
'-A',
'-m',
dest='sampler_name',
type=str,
choices=SAMPLER_CHOICES,
metavar='SAMPLER_NAME',
help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
default='k_lms',
)
render_group.add_argument(
'--log_tokenization',
'-t',
action='store_true',
help='shows how the prompt is split into tokens'
)
render_group.add_argument(
'-f',
'--strength',
type=float,
help='img2img strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
)
render_group.add_argument(
'-T',
'-fit',
'--fit',
action=argparse.BooleanOptionalAction,
help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)',
)
render_group.add_argument(
'--grid',
'-g',
action=argparse.BooleanOptionalAction,
help='generate a grid'
)
render_group.add_argument(
'--embedding_directory',
'--embedding_path',
dest='embedding_path',
default='embeddings',
type=str,
help='Path to a directory containing .bin and/or .pt files, or a single .bin/.pt file. You may use subdirectories. (default is ROOTDIR/embeddings)'
)
render_group.add_argument(
'--lora_directory',
dest='lora_path',
default='loras',
type=str,
help='Path to a directory containing LoRA files; subdirectories are not supported. (default is ROOTDIR/loras)'
)
render_group.add_argument(
'--embeddings',
action=argparse.BooleanOptionalAction,
default=True,
help='Enable embedding directory (default). Use --no-embeddings to disable.',
)
render_group.add_argument(
'--enable_image_debugging',
action='store_true',
help='Generates debugging image to display'
)
render_group.add_argument(
'--karras_max',
type=int,
default=None,
help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29]."
)
# Restoration related args
postprocessing_group.add_argument(
'--no_restore',
dest='restore',
action='store_false',
help='Disable face restoration with GFPGAN or codeformer',
)
postprocessing_group.add_argument(
'--no_upscale',
dest='esrgan',
action='store_false',
help='Disable upscaling with ESRGAN',
)
postprocessing_group.add_argument(
'--esrgan_bg_tile',
type=int,
default=400,
help='Tile size for background sampler, 0 for no tile during testing. Default: 400.',
)
postprocessing_group.add_argument(
'--esrgan_denoise_str',
type=float,
default=0.75,
help='esrgan denoise str. 0 is no denoise, 1 is max denoise. Default: 0.75',
)
postprocessing_group.add_argument(
'--gfpgan_model_path',
type=str,
default='./models/gfpgan/GFPGANv1.4.pth',
help='Indicates the path to the GFPGAN model',
)
web_server_group.add_argument(
'--web',
dest='web',
action='store_true',
help='Start in web server mode.',
)
web_server_group.add_argument(
'--web_develop',
dest='web_develop',
action='store_true',
help='Start in web server development mode.',
)
web_server_group.add_argument(
"--web_verbose",
action="store_true",
help="Enables verbose logging",
)
web_server_group.add_argument(
"--cors",
nargs="*",
type=str,
help="Additional allowed origins, comma-separated",
)
web_server_group.add_argument(
'--host',
type=str,
default='127.0.0.1',
help='Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.'
)
web_server_group.add_argument(
'--port',
type=int,
default='9090',
help='Web server: Port to listen on'
)
web_server_group.add_argument(
'--certfile',
type=str,
default=None,
help='Web server: Path to certificate file to use for SSL. Use together with --keyfile'
)
web_server_group.add_argument(
'--keyfile',
type=str,
default=None,
help='Web server: Path to private key file to use for SSL. Use together with --certfile'
)
web_server_group.add_argument(
'--gui',
dest='gui',
action='store_true',
help='Start InvokeAI GUI',
)
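For reference, a minimal sketch (not part of the diff) of how the legacy parser reads an old-style init file through the `@` prefix; the file path and contents are assumptions for illustration.
```
from pathlib import Path

# Hypothetical legacy init file: one option per line, inline comments allowed
# because convert_arg_line_to_args() splits each line with shlex.split(..., comments=True).
Path("/tmp/invokeai.init").write_text(
    "--outdir=outputs\n"
    "--nsfw_checker   # stripped before parsing\n"
)
opts = legacy_parser.parse_args(["@/tmp/invokeai.init"])
print(opts.outdir, opts.safety_checker)   # -> outputs True
```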


@ -19,13 +19,15 @@ from tqdm import tqdm
import invokeai.configs as configs
from ..globals import Globals, global_cache_dir, global_config_dir
from invokeai.app.services.config import get_invokeai_config
from ..model_management import ModelManager
from ..stable_diffusion import StableDiffusionGeneratorPipeline
warnings.filterwarnings("ignore")
# --------------------------globals-----------------------
config = get_invokeai_config()
Model_dir = "models"
Weights_dir = "ldm/stable-diffusion-v1/"
@ -47,12 +49,11 @@ Config_preamble = """
def default_config_file():
return Path(global_config_dir()) / "models.yaml"
return config.model_conf_path
def sd_configs():
return Path(global_config_dir()) / "stable-diffusion"
return config.legacy_conf_path
def initial_models():
global Datasets
@ -121,8 +122,9 @@ def install_requested_models(
if scan_at_startup and scan_directory.is_dir():
argument = "--autoconvert"
initfile = Path(Globals.root, Globals.initfile)
replacement = Path(Globals.root, f"{Globals.initfile}.new")
print('** The global initfile is no longer supported; rewrite to support new yaml format **')
initfile = Path(config.root, 'invokeai.init')
replacement = Path(config.root, f"invokeai.init.new")
directory = str(scan_directory).replace("\\", "/")
with open(initfile, "r") as input:
with open(replacement, "w") as output:
@ -150,7 +152,7 @@ def get_root(root: str = None) -> str:
elif os.environ.get("INVOKEAI_ROOT"):
return os.environ.get("INVOKEAI_ROOT")
else:
return Globals.root
return config.root
# ---------------------------------------------
@ -183,7 +185,7 @@ def all_datasets() -> dict:
# look for legacy model.ckpt in models directory and offer to
# normalize its name
def migrate_models_ckpt():
model_path = os.path.join(Globals.root, Model_dir, Weights_dir)
model_path = os.path.join(config.root, Model_dir, Weights_dir)
if not os.path.exists(os.path.join(model_path, "model.ckpt")):
return
new_name = initial_models()["stable-diffusion-1.4"]["file"]
@ -228,7 +230,7 @@ def _download_repo_or_file(
def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
repo_id = mconfig["repo_id"]
filename = mconfig["file"]
cache_dir = os.path.join(Globals.root, Model_dir, Weights_dir)
cache_dir = os.path.join(config.root, Model_dir, Weights_dir)
return hf_download_with_resume(
repo_id=repo_id,
model_dir=cache_dir,
@ -239,9 +241,9 @@ def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
# ---------------------------------------------
def download_from_hf(
model_class: object, model_name: str, cache_subdir: Path = Path("hub"), **kwargs
model_class: object, model_name: str, **kwargs
):
path = global_cache_dir(cache_subdir)
path = config.cache_dir
model = model_class.from_pretrained(
model_name,
cache_dir=path,
@ -417,7 +419,7 @@ def new_config_file_contents(
stanza["height"] = mod["height"]
if "file" in mod:
stanza["weights"] = os.path.relpath(
successfully_downloaded[model], start=Globals.root
successfully_downloaded[model], start=config.root
)
stanza["config"] = os.path.normpath(
os.path.join(sd_configs(), mod["config"])
@ -456,7 +458,7 @@ def delete_weights(model_name: str, conf_stanza: dict):
weights = Path(weights)
if not weights.is_absolute():
weights = Path(Globals.root) / weights
weights = Path(config.root) / weights
try:
weights.unlink()
except OSError as e:

File diff suppressed because it is too large


@ -31,6 +31,7 @@ from ..util.util import rand_perlin_2d
from ..safety_checker import SafetyChecker
from ..prompting.conditioning import get_uc_and_c_and_ec
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ..stable_diffusion.schedulers import SCHEDULER_MAP
downsampling = 8
@ -71,19 +72,6 @@ class InvokeAIGeneratorOutput:
# we are interposing a wrapper around the original Generator classes so that
# old code that calls Generate will continue to work.
class InvokeAIGenerator(metaclass=ABCMeta):
scheduler_map = dict(
ddim=diffusers.DDIMScheduler,
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_euler=diffusers.EulerDiscreteScheduler,
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
k_heun=diffusers.HeunDiscreteScheduler,
k_lms=diffusers.LMSDiscreteScheduler,
plms=diffusers.PNDMScheduler,
)
def __init__(self,
model_info: dict,
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
@ -175,14 +163,20 @@ class InvokeAIGenerator(metaclass=ABCMeta):
'''
Return list of all the schedulers that we currently handle.
'''
return list(self.scheduler_map.keys())
return list(SCHEDULER_MAP.keys())
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
return generator_class(model, self.params.precision)
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
scheduler = scheduler_class.from_config(model.scheduler.config)
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
scheduler_config = model.scheduler.config
if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py
if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False
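For reference, a minimal sketch (not part of the diff) of the dictionary-merge pattern `get_scheduler` now uses: the pristine scheduler config is stashed under a `_backup` key so that a later scheduler switch starts from the original values rather than an already-patched copy. The example values are assumptions.
```
original = {"num_train_timesteps": 1000}   # toy scheduler config
extra = {"use_karras_sigmas": True}        # assumed per-scheduler override from SCHEDULER_MAP

config = original
if "_backup" in config:                    # a previous merge happened; recover the pristine copy
    config = config["_backup"]
merged = {**config, **extra, "_backup": config}
# merged == {"num_train_timesteps": 1000, "use_karras_sigmas": True,
#            "_backup": {"num_train_timesteps": 1000}}
```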
@ -226,10 +220,10 @@ class Inpaint(Img2Img):
def generate(self,
mask_image: Image.Image | torch.FloatTensor,
# Seam settings - when 0, doesn't fill seam
seam_size: int = 0,
seam_blur: int = 0,
seam_size: int = 96,
seam_blur: int = 16,
seam_strength: float = 0.7,
seam_steps: int = 10,
seam_steps: int = 30,
tile_size: int = 32,
inpaint_replace=False,
infill_method=None,


@ -4,6 +4,7 @@ invokeai.backend.generator.inpaint descends from .generator
from __future__ import annotations
import math
from typing import Tuple, Union
import cv2
import numpy as np
@ -59,7 +60,7 @@ class Inpaint(Img2Img):
writeable=False,
)
def infill_patchmatch(self, im: Image.Image) -> Image:
def infill_patchmatch(self, im: Image.Image) -> Image.Image:
if im.mode != "RGBA":
return im
@ -75,18 +76,18 @@ class Inpaint(Img2Img):
return im_patched
def tile_fill_missing(
self, im: Image.Image, tile_size: int = 16, seed: int = None
) -> Image:
self, im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None
) -> Image.Image:
# Only fill if there's an alpha layer
if im.mode != "RGBA":
return im
a = np.asarray(im, dtype=np.uint8)
tile_size = (tile_size, tile_size)
tile_size_tuple = (tile_size, tile_size)
# Get the image as tiles of a specified size
tiles = self.get_tile_images(a, *tile_size).copy()
tiles = self.get_tile_images(a, *tile_size_tuple).copy()
# Get the mask as tiles
tiles_mask = tiles[:, :, :, :, 3]
@ -127,7 +128,9 @@ class Inpaint(Img2Img):
return si
def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image:
def mask_edge(
self, mask: Image.Image, edge_size: int, edge_blur: int
) -> Image.Image:
npimg = np.asarray(mask, dtype=np.uint8)
# Detect any partially transparent regions
@ -206,15 +209,15 @@ class Inpaint(Img2Img):
cfg_scale,
ddim_eta,
conditioning,
init_image: PIL.Image.Image | torch.FloatTensor,
mask_image: PIL.Image.Image | torch.FloatTensor,
init_image: Image.Image | torch.FloatTensor,
mask_image: Image.Image | torch.FloatTensor,
strength: float,
mask_blur_radius: int = 8,
# Seam settings - when 0, doesn't fill seam
seam_size: int = 0,
seam_blur: int = 0,
seam_size: int = 96,
seam_blur: int = 16,
seam_strength: float = 0.7,
seam_steps: int = 10,
seam_steps: int = 30,
tile_size: int = 32,
step_callback=None,
inpaint_replace=False,
@ -222,7 +225,7 @@ class Inpaint(Img2Img):
infill_method=None,
inpaint_width=None,
inpaint_height=None,
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
inpaint_fill: Tuple[int, int, int, int] = (0x7F, 0x7F, 0x7F, 0xFF),
attention_maps_callback=None,
**kwargs,
):
@ -239,7 +242,7 @@ class Inpaint(Img2Img):
self.inpaint_width = inpaint_width
self.inpaint_height = inpaint_height
if isinstance(init_image, PIL.Image.Image):
if isinstance(init_image, Image.Image):
self.pil_image = init_image.copy()
# Do infill
@ -250,8 +253,8 @@ class Inpaint(Img2Img):
self.pil_image.copy(), seed=self.seed, tile_size=tile_size
)
elif infill_method == "solid":
solid_bg = PIL.Image.new("RGBA", init_image.size, inpaint_fill)
init_filled = PIL.Image.alpha_composite(solid_bg, init_image)
solid_bg = Image.new("RGBA", init_image.size, inpaint_fill)
init_filled = Image.alpha_composite(solid_bg, init_image)
else:
raise ValueError(
f"Non-supported infill type {infill_method}", infill_method
@ -269,7 +272,7 @@ class Inpaint(Img2Img):
# Create init tensor
init_image = image_resized_to_grid_as_tensor(init_filled.convert("RGB"))
if isinstance(mask_image, PIL.Image.Image):
if isinstance(mask_image, Image.Image):
self.pil_mask = mask_image.copy()
debug_image(
mask_image,


@ -1,122 +0,0 @@
"""
invokeai.backend.globals defines a small number of global variables that would
otherwise have to be passed through long and complex call chains.
It defines a Namespace object named "Globals" that contains
the attributes:
- root - the root directory under which "models" and "outputs" can be found
- initfile - path to the initialization file
- try_patchmatch - option to globally disable loading of 'patchmatch' module
- always_use_cpu - force use of CPU even if GPU is available
"""
import os
import os.path as osp
from argparse import Namespace
from pathlib import Path
from typing import Union
Globals = Namespace()
# Where to look for the initialization file and other key components
Globals.initfile = "invokeai.init"
Globals.models_file = "models.yaml"
Globals.models_dir = "models"
Globals.config_dir = "configs"
Globals.autoscan_dir = "weights"
Globals.converted_ckpts_dir = "converted_ckpts"
# Set the default root directory. This can be overwritten by explicitly
# passing the `--root <directory>` argument on the command line.
# logic is:
# 1) use INVOKEAI_ROOT environment variable (no check for this being a valid directory)
# 2) use VIRTUAL_ENV environment variable, with a check for initfile being there
# 3) use ~/invokeai
if os.environ.get("INVOKEAI_ROOT"):
Globals.root = osp.abspath(os.environ.get("INVOKEAI_ROOT"))
elif (
os.environ.get("VIRTUAL_ENV")
and Path(os.environ.get("VIRTUAL_ENV"), "..", Globals.initfile).exists()
):
Globals.root = osp.abspath(osp.join(os.environ.get("VIRTUAL_ENV"), ".."))
else:
Globals.root = osp.abspath(osp.expanduser("~/invokeai"))
# Try loading patchmatch
Globals.try_patchmatch = True
# Use CPU even if GPU is available (main use case is for debugging MPS issues)
Globals.always_use_cpu = False
# Whether the internet is reachable for dynamic downloads
# The CLI will test connectivity at startup time.
Globals.internet_available = True
# Whether to disable xformers
Globals.disable_xformers = False
# Low-memory tradeoff for guidance calculations.
Globals.sequential_guidance = False
# whether we are forcing full precision
Globals.full_precision = False
# whether we should convert ckpt files into diffusers models on the fly
Globals.ckpt_convert = True
# logging tokenization everywhere
Globals.log_tokenization = False
def global_config_file() -> Path:
return Path(Globals.root, Globals.config_dir, Globals.models_file)
def global_config_dir() -> Path:
return Path(Globals.root, Globals.config_dir)
def global_models_dir() -> Path:
return Path(Globals.root, Globals.models_dir)
def global_autoscan_dir() -> Path:
return Path(Globals.root, Globals.autoscan_dir)
def global_converted_ckpts_dir() -> Path:
return Path(global_models_dir(), Globals.converted_ckpts_dir)
def global_set_root(root_dir: Union[str, Path]):
Globals.root = root_dir
def global_cache_dir(subdir: Union[str, Path] = "") -> Path:
"""
Returns Path to the model cache directory. If a subdirectory
is provided, it will be appended to the end of the path, allowing
for Hugging Face-style conventions. Currently, Hugging Face has
moved all models into the "hub" subfolder, so for any pretrained
HF model, use:
global_cache_dir('hub')
The legacy location for transformers used to be global_cache_dir('transformers')
and global_cache_dir('diffusers') for diffusers.
"""
home: str = os.getenv("HF_HOME")
if home is None:
home = os.getenv("XDG_CACHE_HOME")
if home is not None:
# Set `home` to $XDG_CACHE_HOME/huggingface, which is the default location mentioned in Hugging Face Hub Client Library.
# See: https://huggingface.co/docs/huggingface_hub/main/en/package_reference/environment_variables#xdgcachehome
home += os.sep + "huggingface"
if home is not None:
return Path(home, subdir)
else:
return Path(Globals.root, "models", subdir)
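For reference, a minimal sketch (not part of the diff) of the replacements for the helpers deleted above, as used throughout the rest of this changeset.
```
from invokeai.app.services.config import get_invokeai_config

config = get_invokeai_config()
print(config.root)               # runtime root, formerly Globals.root
print(config.cache_dir)          # Hugging Face hub cache, formerly global_cache_dir("hub")
print(config.model_conf_path)    # models.yaml, formerly global_config_file()
print(config.legacy_conf_path)   # legacy SD configs, formerly global_config_dir() / "stable-diffusion"
```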


@ -6,7 +6,7 @@ be suppressed or deferred
"""
import numpy as np
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
from invokeai.app.services.config import get_invokeai_config
class PatchMatch:
"""
@ -21,9 +21,10 @@ class PatchMatch:
@classmethod
def _load_patch_match(self):
config = get_invokeai_config()
if self.tried_load:
return
if Globals.try_patchmatch:
if config.try_patchmatch:
from patchmatch import patch_match as pm
if pm.patchmatch_available:


@ -33,12 +33,11 @@ from PIL import Image, ImageOps
from transformers import AutoProcessor, CLIPSegForImageSegmentation
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import global_cache_dir
from invokeai.app.services.config import get_invokeai_config
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
CLIPSEG_SIZE = 352
class SegmentedGrayscale(object):
def __init__(self, image: Image, heatmap: torch.Tensor):
self.heatmap = heatmap
@ -84,14 +83,15 @@ class Txt2Mask(object):
def __init__(self, device="cpu", refined=False):
logger.info("Initializing clipseg model for text to mask inference")
config = get_invokeai_config()
# BUG: we are not doing anything with the device option at this time
self.device = device
self.processor = AutoProcessor.from_pretrained(
CLIPSEG_MODEL, cache_dir=global_cache_dir("hub")
CLIPSEG_MODEL, cache_dir=config.cache_dir
)
self.model = CLIPSegForImageSegmentation.from_pretrained(
CLIPSEG_MODEL, cache_dir=global_cache_dir("hub")
CLIPSEG_MODEL, cache_dir=config.cache_dir
)
@torch.no_grad()


@ -26,7 +26,7 @@ import torch
from safetensors.torch import load_file
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import global_cache_dir, global_config_dir
from invokeai.app.services.config import get_invokeai_config
from .model_manager import ModelManager, SDLegacyType
@ -47,6 +47,7 @@ from diffusers import (
LDMTextToImagePipeline,
LMSDiscreteScheduler,
PNDMScheduler,
UniPCMultistepScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
@ -73,7 +74,6 @@ from transformers import (
from ..stable_diffusion import StableDiffusionGeneratorPipeline
def shave_segments(path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
@ -842,7 +842,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
def convert_ldm_clip_checkpoint(checkpoint):
text_model = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14", cache_dir=global_cache_dir("hub")
"openai/clip-vit-large-patch14", cache_dir=get_invokeai_config().cache_dir
)
keys = list(checkpoint.keys())
@ -897,7 +897,7 @@ textenc_pattern = re.compile("|".join(protected.keys()))
def convert_paint_by_example_checkpoint(checkpoint):
cache_dir = global_cache_dir("hub")
cache_dir = get_invokeai_config().cache_dir
config = CLIPVisionConfig.from_pretrained(
"openai/clip-vit-large-patch14", cache_dir=cache_dir
)
@ -969,7 +969,7 @@ def convert_paint_by_example_checkpoint(checkpoint):
def convert_open_clip_checkpoint(checkpoint):
cache_dir = global_cache_dir("hub")
cache_dir = get_invokeai_config().cache_dir
text_model = CLIPTextModel.from_pretrained(
"stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir
)
@ -1092,7 +1092,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
:param vae: A diffusers VAE to load into the pipeline.
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
"""
config = get_invokeai_config()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity()
@ -1105,7 +1105,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
else:
checkpoint = load_file(checkpoint_path)
cache_dir = global_cache_dir("hub")
cache_dir = config.cache_dir
pipeline_class = (
StableDiffusionGeneratorPipeline
if return_generator_pipeline
@ -1129,25 +1129,23 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
if model_type == SDLegacyType.V2_v:
original_config_file = (
global_config_dir() / "stable-diffusion" / "v2-inference-v.yaml"
config.legacy_conf_path / "v2-inference-v.yaml"
)
if global_step == 110000:
# v2.1 needs to upcast attention
upcast_attention = True
elif model_type == SDLegacyType.V2_e:
original_config_file = (
global_config_dir() / "stable-diffusion" / "v2-inference.yaml"
config.legacy_conf_path / "v2-inference.yaml"
)
elif model_type == SDLegacyType.V1_INPAINT:
original_config_file = (
global_config_dir()
/ "stable-diffusion"
/ "v1-inpainting-inference.yaml"
config.legacy_conf_path / "v1-inpainting-inference.yaml"
)
elif model_type == SDLegacyType.V1:
original_config_file = (
global_config_dir() / "stable-diffusion" / "v1-inference.yaml"
config.legacy_conf_path / "v1-inference.yaml"
)
else:
@ -1209,6 +1207,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "dpm":
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
elif scheduler_type == 'unipc':
scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
elif scheduler_type == "ddim":
scheduler = scheduler
else:
@ -1297,7 +1297,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
"CompVis/stable-diffusion-safety-checker",
cache_dir=global_cache_dir("hub"),
cache_dir=config.cache_dir,
)
feature_extractor = AutoFeatureExtractor.from_pretrained(
"CompVis/stable-diffusion-safety-checker", cache_dir=cache_dir


@ -36,8 +36,6 @@ from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from picklescan.scanner import scan_file_path
from invokeai.backend.globals import Globals, global_cache_dir
from transformers import (
CLIPTextModel,
CLIPTokenizer,
@ -49,9 +47,9 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
from ..stable_diffusion import (
StableDiffusionGeneratorPipeline,
)
from invokeai.app.services.config import get_invokeai_config
from ..util import CUDA_DEVICE, ask_user, download_with_resume
class SDLegacyType(Enum):
V1 = auto()
V1_INPAINT = auto()
@ -100,6 +98,7 @@ class ModelManager(object):
if not isinstance(config, DictConfig):
config = OmegaConf.load(config)
self.config = config
self.globals = get_invokeai_config()
self.precision = precision
self.device = torch.device(device_type)
self.max_loaded_models = max_loaded_models
@ -292,7 +291,7 @@ class ModelManager(object):
"""
# if we are converting legacy files automatically, then
# there are no legacy ckpts!
if Globals.ckpt_convert:
if self.globals.ckpt_convert:
return False
info = self.model_info(model_name)
if "weights" in info and info["weights"].endswith((".ckpt", ".safetensors")):
@ -502,13 +501,13 @@ class ModelManager(object):
# TODO: scan weights maybe?
pipeline_args: dict[str, Any] = dict(
safety_checker=None, local_files_only=not Globals.internet_available
safety_checker=None, local_files_only=not self.globals.internet_available
)
if "vae" in mconfig and mconfig["vae"] is not None:
if vae := self._load_vae(mconfig["vae"]):
pipeline_args.update(vae=vae)
if not isinstance(name_or_path, Path):
pipeline_args.update(cache_dir=global_cache_dir("hub"))
pipeline_args.update(cache_dir=self.globals.cache_dir)
if using_fp16:
pipeline_args.update(torch_dtype=torch.float16)
fp_args_list = [{"revision": "fp16"}, {}]
@ -560,10 +559,9 @@ class ModelManager(object):
width = mconfig.width
height = mconfig.height
if not os.path.isabs(config):
config = os.path.join(Globals.root, config)
if not os.path.isabs(weights):
weights = os.path.normpath(os.path.join(Globals.root, weights))
root_dir = self.globals.root_dir
config = str(root_dir / config)
weights = str(root_dir / weights)
# Convert to diffusers and return a diffusers pipeline
self.logger.info(f"Converting legacy checkpoint {model_name} into a diffusers model...")
@ -578,11 +576,7 @@ class ModelManager(object):
vae_path = None
if vae:
vae_path = (
vae
if os.path.isabs(vae)
else os.path.normpath(os.path.join(Globals.root, vae))
)
vae_path = str(root_dir / vae)
if self._has_cuda():
torch.cuda.empty_cache()
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
@ -614,9 +608,7 @@ class ModelManager(object):
)
if "path" in mconfig and mconfig["path"] is not None:
path = Path(mconfig["path"])
if not path.is_absolute():
path = Path(Globals.root, path).resolve()
path = self.globals.root_dir / Path(mconfig["path"])
return path
elif "repo_id" in mconfig:
return mconfig["repo_id"]
@ -864,25 +856,16 @@ class ModelManager(object):
model_type = self.probe_model_type(checkpoint)
if model_type == SDLegacyType.V1:
self.logger.debug("SD-v1 model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v1-inference.yaml"
)
model_config_file = self.globals.legacy_conf_path / "v1-inference.yaml"
elif model_type == SDLegacyType.V1_INPAINT:
self.logger.debug("SD-v1 inpainting model detected")
model_config_file = Path(
Globals.root,
"configs/stable-diffusion/v1-inpainting-inference.yaml",
)
model_config_file = self.globals.legacy_conf_path / "v1-inpainting-inference.yaml"
elif model_type == SDLegacyType.V2_v:
self.logger.debug("SD-v2-v model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
)
model_config_file = self.globals.legacy_conf_path / "v2-inference-v.yaml"
elif model_type == SDLegacyType.V2_e:
self.logger.debug("SD-v2-e model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
)
model_config_file = self.globals.legacy_conf_path / "v2-inference.yaml"
elif model_type == SDLegacyType.V2:
self.logger.warning(
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
@ -909,9 +892,7 @@ class ModelManager(object):
self.logger.debug(f"Using VAE file {vae_path.name}")
vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
diffuser_path = Path(
Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
)
diffuser_path = self.globals.root_dir / "models/converted_ckpts" / model_path.stem
model_name = self.convert_and_import(
model_path,
diffusers_path=diffuser_path,
@ -1044,9 +1025,7 @@ class ModelManager(object):
"""
yaml_str = OmegaConf.to_yaml(self.config)
if not os.path.isabs(config_file_path):
config_file_path = os.path.normpath(
os.path.join(Globals.root, config_file_path)
)
config_file_path = self.globals.model_conf_path
tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
with open(tmpfile, "w", encoding="utf-8") as outfile:
outfile.write(self.preamble())
@ -1078,7 +1057,8 @@ class ModelManager(object):
"""
# Three transformer models to check: bert, clip and safety checker, and
# the diffusers as well
models_dir = Path(Globals.root, "models")
config = get_invokeai_config()
models_dir = config.root_dir / "models"
legacy_locations = [
Path(
models_dir,
@ -1090,8 +1070,8 @@ class ModelManager(object):
"openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14",
),
]
legacy_locations.extend(list(global_cache_dir("diffusers").glob("*")))
legacy_cache_dir = config.cache_dir / "../diffusers"
legacy_locations.extend(list(legacy_cache_dir.glob("*")))
legacy_layout = False
for model in legacy_locations:
legacy_layout = legacy_layout or model.exists()
@ -1113,7 +1093,7 @@ class ModelManager(object):
# transformer files get moved into the hub directory
if cls._is_huggingface_hub_directory_present():
hub = global_cache_dir("hub")
hub = config.cache_dir
else:
hub = models_dir / "hub"
@ -1152,13 +1132,12 @@ class ModelManager(object):
if str(source).startswith(("http:", "https:", "ftp:")):
dest_directory = Path(dest_directory)
if not dest_directory.is_absolute():
dest_directory = Globals.root / dest_directory
dest_directory = self.globals.root_dir / dest_directory
dest_directory.mkdir(parents=True, exist_ok=True)
resolved_path = download_with_resume(str(source), dest_directory)
else:
if not os.path.isabs(source):
source = os.path.join(Globals.root, source)
resolved_path = Path(source)
source = self.globals.root_dir / source
resolved_path = source
return resolved_path
def _invalidate_cached_model(self, model_name: str) -> None:
@ -1208,7 +1187,7 @@ class ModelManager(object):
path = name_or_path
else:
owner, repo = name_or_path.split("/")
path = Path(global_cache_dir("hub") / f"models--{owner}--{repo}")
path = self.globals.cache_dir / f"models--{owner}--{repo}"
if not path.exists():
return None
hashpath = path / "checksum.sha256"
@ -1228,7 +1207,7 @@ class ModelManager(object):
sha.update(chunk)
hash = sha.hexdigest()
toc = time.time()
self.logger.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
self.logger.debug(f"sha256 = {hash} ({count} files hashed in {toc - tic:4.2f}s)")
with open(hashpath, "w") as f:
f.write(hash)
return hash
@ -1269,8 +1248,8 @@ class ModelManager(object):
using_fp16 = self.precision == "float16"
vae_args.update(
cache_dir=global_cache_dir("hub"),
local_files_only=not Globals.internet_available,
cache_dir=self.globals.cache_dir,
local_files_only=not self.globals.internet_available,
)
self.logger.debug(f"Loading diffusers VAE from {name_or_path}")
@ -1308,7 +1287,7 @@ class ModelManager(object):
@classmethod
def _delete_model_from_cache(cls,repo_id):
cache_info = scan_cache_dir(global_cache_dir("hub"))
cache_info = scan_cache_dir(get_invokeai_config().cache_dir)
# I'm sure there is a way to do this with comprehensions
# but the code quickly became incomprehensible!
@ -1325,9 +1304,10 @@ class ModelManager(object):
@staticmethod
def _abs_path(path: str | Path) -> Path:
globals = get_invokeai_config()
if path is None or Path(path).is_absolute():
return path
return Path(Globals.root, path).resolve()
return Path(globals.root_dir, path).resolve()
@staticmethod
def _is_huggingface_hub_directory_present() -> bool:


@ -16,67 +16,59 @@ from compel.prompt_parser import (
FlattenedPrompt,
Fragment,
PromptParser,
Conjunction,
)
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
from invokeai.app.services.config import get_invokeai_config
from ..stable_diffusion import InvokeAIDiffuserComponent
from ..util import torch_dtype
def get_uc_and_c_and_ec(
prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False
):
def get_uc_and_c_and_ec(prompt_string,
model: InvokeAIDiffuserComponent,
log_tokens=False, skip_normalize_legacy_blend=False):
# lazy-load any deferred textual inversions.
# this might take a couple of seconds the first time a textual inversion is used.
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
prompt_string
)
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)
tokenizer = model.tokenizer
compel = Compel(
tokenizer=tokenizer,
compel = Compel(tokenizer=model.tokenizer,
text_encoder=model.text_encoder,
textual_inversion_manager=model.textual_inversion_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=False
truncate_long_prompts=False,
)
config = get_invokeai_config()
# get rid of any newline characters
prompt_string = prompt_string.replace("\n", " ")
(
positive_prompt_string,
negative_prompt_string,
) = split_prompt_to_positive_and_negative(prompt_string)
legacy_blend = try_parse_legacy_blend(
positive_prompt_string, skip_normalize_legacy_blend
)
positive_prompt: Union[FlattenedPrompt, Blend]
if legacy_blend is not None:
positive_prompt = legacy_blend
else:
positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
negative_prompt_string
)
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
if log_tokens or getattr(Globals, "log_tokenization", False):
log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
positive_conjunction: Conjunction
if legacy_blend is not None:
positive_conjunction = legacy_blend
else:
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
positive_prompt = positive_conjunction.prompts[0]
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]
tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
if log_tokens or config.log_tokenization:
log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)
c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
tokens_count = get_max_token_count(tokenizer, positive_prompt)
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=tokens_count,
cross_attention_control_args=options.get("cross_attention_control", None),
)
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
cross_attention_control_args=options.get(
'cross_attention_control', None))
return uc, c, ec
def get_prompt_structure(
prompt_string, skip_normalize_legacy_blend: bool = False
) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt):
@ -87,18 +79,17 @@ def get_prompt_structure(
legacy_blend = try_parse_legacy_blend(
positive_prompt_string, skip_normalize_legacy_blend
)
positive_prompt: Union[FlattenedPrompt, Blend]
positive_prompt: Conjunction
if legacy_blend is not None:
positive_prompt = legacy_blend
positive_conjunction = legacy_blend
else:
positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
negative_prompt_string
)
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
positive_prompt = positive_conjunction.prompts[0]
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0]
return positive_prompt, negative_prompt
def get_max_token_count(
tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
) -> int:
@ -245,22 +236,21 @@ def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_t
logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
logger.debug(f"{discarded}\x1b[0m")
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Conjunction]:
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
if len(weighted_subprompts) <= 1:
return None
strings = [x[0] for x in weighted_subprompts]
weights = [x[1] for x in weighted_subprompts]
pp = PromptParser()
parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
return Blend(
prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize
)
flattened_prompts = []
weights = []
for i, x in enumerate(parsed_conjunctions):
if len(x.prompts)>0:
flattened_prompts.append(x.prompts[0])
weights.append(weighted_subprompts[i][1])
return Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)])
def split_weighted_subprompts(text, skip_normalize=False) -> list:
"""


@ -6,7 +6,7 @@ import numpy as np
import torch
import invokeai.backend.util.logging as logger
from ..globals import Globals
from invokeai.app.services.config import get_invokeai_config
pretrained_model_url = (
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
@ -17,11 +17,11 @@ class CodeFormerRestoration:
def __init__(
self, codeformer_dir="models/codeformer", codeformer_model_path="codeformer.pth"
) -> None:
if not os.path.isabs(codeformer_dir):
codeformer_dir = os.path.join(Globals.root, codeformer_dir)
self.model_path = os.path.join(codeformer_dir, codeformer_model_path)
self.codeformer_model_exists = os.path.isfile(self.model_path)
self.globals = get_invokeai_config()
codeformer_dir = self.globals.root_dir / codeformer_dir
self.model_path = codeformer_dir / codeformer_model_path
self.codeformer_model_exists = self.model_path.exists()
if not self.codeformer_model_exists:
logger.error("NOT FOUND: CodeFormer model not found at " + self.model_path)
@ -71,9 +71,7 @@ class CodeFormerRestoration:
upscale_factor=1,
use_parse=True,
device=device,
model_rootpath=os.path.join(
Globals.root, "models", "gfpgan", "weights"
),
model_rootpath = self.globals.root_dir / "gfpgan" / "weights"
)
face_helper.clean_all()
face_helper.read_image(bgr_image_array)


@ -7,14 +7,13 @@ import torch
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
from invokeai.app.services.config import get_invokeai_config
class GFPGAN:
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
self.globals = get_invokeai_config()
if not os.path.isabs(gfpgan_model_path):
gfpgan_model_path = os.path.abspath(
os.path.join(Globals.root, gfpgan_model_path)
)
gfpgan_model_path = self.globals.root_dir / gfpgan_model_path
self.model_path = gfpgan_model_path
self.gfpgan_model_exists = os.path.isfile(self.model_path)
@ -33,7 +32,7 @@ class GFPGAN:
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
cwd = os.getcwd()
os.chdir(os.path.join(Globals.root, "models"))
os.chdir(self.globals.root_dir / 'models')
try:
from gfpgan import GFPGANer


@ -1,4 +1,3 @@
import os
import warnings
import numpy as np
@ -7,7 +6,8 @@ from PIL import Image
from PIL.Image import Image as ImageType
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
from invokeai.app.services.config import get_invokeai_config
config = get_invokeai_config()
class ESRGAN:
def __init__(self, bg_tile_size=400) -> None:
@ -30,12 +30,8 @@ class ESRGAN:
upscale=4,
act_type="prelu",
)
model_path = os.path.join(
Globals.root, "models/realesrgan/realesr-general-x4v3.pth"
)
wdn_model_path = os.path.join(
Globals.root, "models/realesrgan/realesr-general-wdn-x4v3.pth"
)
model_path = config.root_dir / "models/realesrgan/realesr-general-x4v3.pth"
wdn_model_path = config.root_dir / "models/realesrgan/realesr-general-wdn-x4v3.pth"
scale = 4
bg_upsampler = RealESRGANer(


@ -15,7 +15,7 @@ from transformers import AutoFeatureExtractor
import invokeai.assets.web as web_assets
import invokeai.backend.util.logging as logger
from .globals import global_cache_dir
from invokeai.app.services.config import get_invokeai_config
from .util import CPU_DEVICE
class SafetyChecker(object):
@ -26,10 +26,11 @@ class SafetyChecker(object):
caution = Image.open(path)
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
self.device = device
config = get_invokeai_config()
try:
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_model_path = global_cache_dir("hub")
safety_model_path = config.cache_dir
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
safety_model_id,
local_files_only=True,


@ -18,15 +18,15 @@ from huggingface_hub import (
)
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
from invokeai.app.services.config import get_invokeai_config
class HuggingFaceConceptsLibrary(object):
def __init__(self, root=None):
"""
Initialize the Concepts object. May optionally pass a root directory.
"""
self.root = root or Globals.root
self.config = get_invokeai_config()
self.root = root or self.config.root
self.hf_api = HfApi()
self.local_concepts = dict()
self.concept_list = None
@ -58,7 +58,7 @@ class HuggingFaceConceptsLibrary(object):
self.concept_list.extend(list(local_concepts_to_add))
return self.concept_list
return self.concept_list
elif Globals.internet_available is True:
elif self.config.internet_available is True:
try:
models = self.hf_api.list_models(
filter=ModelFilter(model_name="sd-concepts-library/")


@ -33,8 +33,7 @@ from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from typing_extensions import ParamSpec
from invokeai.backend.globals import Globals
from invokeai.app.services.config import get_invokeai_config
from ..util import CPU_DEVICE, normalize_device
from .diffusion import (
AttentionMapSaver,
@ -44,7 +43,6 @@ from .diffusion import (
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
from .textual_inversion_manager import TextualInversionManager
@dataclass
class PipelineIntermediateState:
run_id: str
@ -348,10 +346,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
"""
if xformers is available, use it, otherwise use sliced attention.
"""
config = get_invokeai_config()
if (
torch.cuda.is_available()
and is_xformers_available()
and not Globals.disable_xformers
and not config.disable_xformers
):
self.enable_xformers_memory_efficient_attention()
else:
@ -509,10 +508,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
run_id=None,
callback: Callable[[PipelineIntermediateState], None] = None,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if self.scheduler.config.get("cpu_only", False):
scheduler_device = torch.device('cpu')
else:
scheduler_device = self._model_group.device_for(self.unet)
if timesteps is None:
self.scheduler.set_timesteps(
num_inference_steps, device=self._model_group.device_for(self.unet)
)
self.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
timesteps = self.scheduler.timesteps
infer_latents_from_embeddings = GeneratorToCallbackinator(
self.generate_latents_from_embeddings, PipelineIntermediateState
@ -545,6 +547,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance = []
extra_conditioning_info = conditioning_data.extra
with self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model,
extra_conditioning_info=extra_conditioning_info,
step_count=len(self.scheduler.timesteps),
):
@ -726,11 +729,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
run_id=None,
callback=None,
) -> InvokeAIStableDiffusionPipelineOutput:
timesteps, _ = self.get_img2img_timesteps(
num_inference_steps,
strength,
device=self._model_group.device_for(self.unet),
)
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
result_latents, result_attention_maps = self.latents_from_embeddings(
latents=initial_latents if strength < 1.0 else torch.zeros_like(
initial_latents, device=initial_latents.device, dtype=initial_latents.dtype
@ -756,13 +755,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
return self.check_for_safety(output, dtype=conditioning_data.dtype)
def get_img2img_timesteps(
self, num_inference_steps: int, strength: float, device
self, num_inference_steps: int, strength: float, device=None
) -> (torch.Tensor, int):
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
assert img2img_pipeline.scheduler is self.scheduler
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
if self.scheduler.config.get("cpu_only", False):
scheduler_device = torch.device('cpu')
else:
scheduler_device = self._model_group.device_for(self.unet)
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
timesteps, adjusted_steps = img2img_pipeline.get_timesteps(
num_inference_steps, strength, device=device
num_inference_steps, strength, device=scheduler_device
)
# Workaround for low strength resulting in zero timesteps.
# TODO: submit upstream fix for zero-step img2img
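For reference, a minimal sketch (not part of the diff) of the device-selection pattern the two hunks above now share; which schedulers actually set `cpu_only` in their config is an assumption here.
```
import torch

def scheduler_device_for(scheduler, unet_device: torch.device) -> torch.device:
    # Schedulers whose config carries cpu_only=True do their timestep math on the CPU;
    # everything else stays on the device the UNet was placed on.
    if scheduler.config.get("cpu_only", False):
        return torch.device("cpu")
    return unet_device
```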
@ -796,9 +801,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if init_image.dim() == 3:
init_image = init_image.unsqueeze(0)
timesteps, _ = self.get_img2img_timesteps(
num_inference_steps, strength, device=device
)
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
# 6. Prepare latent variables
# can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents

View File

@ -10,6 +10,7 @@ import diffusers
import psutil
import torch
from compel.cross_attention_control import Arguments
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.models.attention_processor import AttentionProcessor
from torch import nn
@ -352,8 +353,7 @@ def restore_default_cross_attention(
else:
remove_attention_function(model)
def override_cross_attention(model, context: Context, is_running_diffusers=False):
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
"""
Inject attention parameters and functions into the passed in model to enable cross attention editing.
@ -372,15 +372,13 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False
indices = torch.arange(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
if b0 < max_length:
if name == "equal": # or (name == "replace" and a1 - a0 == b1 - b0):
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
# these tokens have not been edited
indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1
context.cross_attention_mask = mask.to(device)
context.cross_attention_index_map = indices.to(device)
if is_running_diffusers:
unet = model
old_attn_processors = unet.attn_processors
if torch.backends.mps.is_available():
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
@ -388,21 +386,8 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False
else:
# try to re-use an existing slice size
default_slice_size = 4
slice_size = next(
(
p.slice_size
for p in old_attn_processors.values()
if type(p) is SlicedAttnProcessor
),
default_slice_size,
)
slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
return old_attn_processors
else:
context.register_cross_attention_modules(model)
inject_attention_function(model, context)
return None
def get_cross_attention_modules(
model, which: CrossAttentionType

View File

@ -5,11 +5,12 @@ from typing import Any, Callable, Dict, Optional, Union
import numpy as np
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttentionProcessor
from typing_extensions import TypeAlias
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
from invokeai.app.services.config import get_invokeai_config
from .cross_attention_control import (
Arguments,
@ -17,8 +18,8 @@ from .cross_attention_control import (
CrossAttentionType,
SwapCrossAttnContext,
get_cross_attention_modules,
override_cross_attention,
restore_default_cross_attention,
setup_cross_attention_control_attention_processors,
)
from .cross_attention_map_saving import AttentionMapSaver
@ -31,7 +32,6 @@ ModelForwardCallback: TypeAlias = Union[
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
]
@dataclass(frozen=True)
class PostprocessingSettings:
threshold: float
@ -72,31 +72,43 @@ class InvokeAIDiffuserComponent:
:param model: the unet model to pass through to cross attention control
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
"""
config = get_invokeai_config()
self.conditioning = None
self.model = model
self.is_running_diffusers = is_running_diffusers
self.model_forward_callback = model_forward_callback
self.cross_attention_control_context = None
self.sequential_guidance = Globals.sequential_guidance
self.sequential_guidance = config.sequential_guidance
@classmethod
@contextmanager
def custom_attention_context(
self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
cls,
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
extra_conditioning_info: Optional[ExtraConditioningInfo],
step_count: int
):
do_swap = (
extra_conditioning_info is not None
and extra_conditioning_info.wants_cross_attention_control
old_attn_processors = None
if extra_conditioning_info and (
extra_conditioning_info.wants_cross_attention_control
):
old_attn_processors = unet.attn_processors
# Load lora conditions into the model
if extra_conditioning_info.wants_cross_attention_control:
cross_attention_control_context = Context(
arguments=extra_conditioning_info.cross_attention_control_args,
step_count=step_count,
)
old_attn_processor = None
if do_swap:
old_attn_processor = self.override_cross_attention(
extra_conditioning_info, step_count=step_count
setup_cross_attention_control_attention_processors(
unet,
cross_attention_control_context,
)
try:
yield None
finally:
if old_attn_processor is not None:
self.restore_default_cross_attention(old_attn_processor)
if old_attn_processors is not None:
unet.set_attn_processor(old_attn_processors)
# TODO resuscitate attention map saving
# self.remove_attention_map_saving()
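Taken together with the call site at the top of this diff, the context manager is now a classmethod that is handed the UNet directly instead of reaching for `self.model`. A minimal usage sketch (the `pipeline` and `conditioning_data` names are illustrative, not from the diff):

```python
# Sketch only: invoking the reworked classmethod context manager.
# Assumes `pipeline` is a StableDiffusionGeneratorPipeline and
# `conditioning_data.extra` holds an ExtraConditioningInfo (or None).
with InvokeAIDiffuserComponent.custom_attention_context(
    pipeline.unet,                                    # the UNet2DConditionModel to patch
    extra_conditioning_info=conditioning_data.extra,  # None disables cross-attention control
    step_count=len(pipeline.scheduler.timesteps),
):
    # cross-attention processors are swapped in here; the originals
    # are restored by the finally-block when the context exits
    ...
```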

View File

@ -0,0 +1 @@
from .schedulers import SCHEDULER_MAP

View File

@ -0,0 +1,23 @@
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, \
KDPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, \
HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UniPCMultistepScheduler, \
DPMSolverSinglestepScheduler, DEISMultistepScheduler, DDPMScheduler
SCHEDULER_MAP = dict(
ddim=(DDIMScheduler, dict()),
ddpm=(DDPMScheduler, dict()),
deis=(DEISMultistepScheduler, dict()),
lms=(LMSDiscreteScheduler, dict()),
pndm=(PNDMScheduler, dict()),
heun=(HeunDiscreteScheduler, dict(use_karras_sigmas=False)),
heun_k=(HeunDiscreteScheduler, dict(use_karras_sigmas=True)),
euler=(EulerDiscreteScheduler, dict(use_karras_sigmas=False)),
euler_k=(EulerDiscreteScheduler, dict(use_karras_sigmas=True)),
euler_a=(EulerAncestralDiscreteScheduler, dict()),
kdpm_2=(KDPM2DiscreteScheduler, dict()),
kdpm_2_a=(KDPM2AncestralDiscreteScheduler, dict()),
dpmpp_2s=(DPMSolverSinglestepScheduler, dict()),
dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)),
dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)),
unipc=(UniPCMultistepScheduler, dict(cpu_only=True))
)
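Each entry pairs a short scheduler name with a diffusers scheduler class and the extra kwargs it needs (e.g. `use_karras_sigmas`). A hedged sketch of swapping a pipeline's scheduler via this map; the `pipeline` variable is assumed, and the import path is inferred from the new `__init__.py` above and may differ:

```python
# Sketch: rebuild a scheduler from SCHEDULER_MAP, reusing the pipeline's existing config
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP  # assumed path

def make_scheduler(pipeline, name: str = "dpmpp_2m_k"):
    scheduler_class, extra_kwargs = SCHEDULER_MAP[name]
    # from_config keeps beta schedule, prediction type, etc. and applies the per-name kwargs
    return scheduler_class.from_config(pipeline.scheduler.config, **extra_kwargs)

pipeline.scheduler = make_scheduler(pipeline, "euler_k")
```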

View File

@ -7,7 +7,6 @@
This is the backend to "textual_inversion.py"
"""
import argparse
import logging
import math
import os
@ -47,8 +46,7 @@ from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
# invokeai stuff
from ..args import ArgFormatter, PagingArgumentParser
from ..globals import Globals, global_cache_dir
from invokeai.app.services.config import InvokeAIAppConfig, PagingArgumentParser
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
PIL_INTERPOLATION = {
@ -90,8 +88,9 @@ def save_progress(
def parse_args():
config = InvokeAIAppConfig(argv=[])
parser = PagingArgumentParser(
description="Textual inversion training", formatter_class=ArgFormatter
description="Textual inversion training"
)
general_group = parser.add_argument_group("General")
model_group = parser.add_argument_group("Models and Paths")
@ -112,7 +111,7 @@ def parse_args():
"--root_dir",
"--root",
type=Path,
default=Globals.root,
default=config.root,
help="Path to the invokeai runtime directory",
)
general_group.add_argument(
@ -127,7 +126,7 @@ def parse_args():
general_group.add_argument(
"--output_dir",
type=Path,
default=f"{Globals.root}/text-inversion-model",
default=f"{config.root}/text-inversion-model",
help="The output directory where the model predictions and checkpoints will be written.",
)
model_group.add_argument(
@ -528,6 +527,7 @@ def get_full_repo_name(
def do_textual_inversion_training(
config: InvokeAIAppConfig,
model: str,
train_data_dir: Path,
output_dir: Path,
@ -580,7 +580,7 @@ def do_textual_inversion_training(
# setting up things the way invokeai expects them
if not os.path.isabs(output_dir):
output_dir = os.path.join(Globals.root, output_dir)
output_dir = os.path.join(config.root, output_dir)
logging_dir = output_dir / logging_dir
@ -628,7 +628,7 @@ def do_textual_inversion_training(
elif output_dir is not None:
os.makedirs(output_dir, exist_ok=True)
models_conf = OmegaConf.load(os.path.join(Globals.root, "configs/models.yaml"))
models_conf = OmegaConf.load(config.model_conf_path)
model_conf = models_conf.get(model, None)
assert model_conf is not None, f"Unknown model: {model}"
assert (
@ -640,7 +640,7 @@ def do_textual_inversion_training(
assert (
pretrained_model_name_or_path
), f"models.yaml error: neither 'repo_id' nor 'path' is defined for {model}"
pipeline_args = dict(cache_dir=global_cache_dir("hub"))
pipeline_args = dict(cache_dir=config.cache_dir)
# Load tokenizer
if tokenizer_name:
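The rest of the training script follows the same substitution pattern seen throughout this PR: every `Globals` lookup becomes an attribute on an `InvokeAIAppConfig` instance. A brief sketch, using only attribute names that appear in the diff above:

```python
from invokeai.app.services.config import InvokeAIAppConfig

config = InvokeAIAppConfig(argv=[])   # empty argv: don't parse the script's own command line
root = config.root                    # replaces Globals.root
models_yaml = config.model_conf_path  # replaces os.path.join(Globals.root, "configs/models.yaml")
hf_cache = config.cache_dir           # replaces global_cache_dir("hub")
```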

View File

@ -4,17 +4,16 @@ from contextlib import nullcontext
import torch
from torch import autocast
from invokeai.backend.globals import Globals
from invokeai.app.services.config import get_invokeai_config
CPU_DEVICE = torch.device("cpu")
CUDA_DEVICE = torch.device("cuda")
MPS_DEVICE = torch.device("mps")
def choose_torch_device() -> torch.device:
"""Convenience routine for guessing which GPU device to run model on"""
if Globals.always_use_cpu:
config = get_invokeai_config()
if config.always_use_cpu:
return CPU_DEVICE
if torch.cuda.is_available():
return torch.device("cuda")
@ -33,7 +32,8 @@ def choose_precision(device: torch.device) -> str:
def torch_dtype(device: torch.device) -> torch.dtype:
if Globals.full_precision:
config = get_invokeai_config()
if config.full_precision:
return torch.float32
if choose_precision(device) == "float16":
return torch.float16
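With `Globals` gone, the device helpers consult the config service at call time instead of a module-level global. A small usage sketch (the module path is assumed to be the devices util shown here):

```python
import torch
from invokeai.backend.util.devices import choose_torch_device, torch_dtype  # assumed path

device = choose_torch_device()  # returns CPU when config.always_use_cpu, else CUDA/MPS if available
dtype = torch_dtype(device)     # float32 when config.full_precision, else the per-device choice
latents = torch.zeros((1, 4, 64, 64), device=device, dtype=dtype)
```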

View File

@ -2,34 +2,37 @@
"""invokeai.util.logging
Logging class for InvokeAI that produces console messages that follow
the conventions established in InvokeAI 1.X through 2.X.
Logging class for InvokeAI that produces console messages
One way to use it:
Usage:
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.getLogger(__name__)
logger.critical('this is critical')
logger.error('this is an error')
logger.warning('this is a warning')
logger.info('this is info')
logger.debug('this is debugging')
logger = InvokeAILogger.getLogger(name='InvokeAI') // Initialization
(or)
logger = InvokeAILogger.getLogger(__name__) // To use the filename
logger.critical('this is critical') // Critical Message
logger.error('this is an error') // Error Message
logger.warning('this is a warning') // Warning Message
logger.info('this is info') // Info Message
logger.debug('this is debugging') // Debug Message
Console messages:
### this is critical
*** this is an error ***
** this is a warning
>> this is info
| this is debugging
[12-05-2023 20]::[InvokeAI]::CRITICAL --> This is a critical message [In Bold Red]
[12-05-2023 20]::[InvokeAI]::ERROR --> This is an error message [In Red]
[12-05-2023 20]::[InvokeAI]::WARNING --> This is a warning message [In Yellow]
[12-05-2023 20]::[InvokeAI]::INFO --> This is an info message [In Grey]
[12-05-2023 20]::[InvokeAI]::DEBUG --> This is a debug message [In Cyan]
Another way:
import invokeai.backend.util.logging as ialog
ialog.debug('this is a debugging message')
Alternate Method (in this case the logger name will be set to InvokeAI):
import invokeai.backend.util.logging as IAILogger
IAILogger.debug('this is a debugging message')
"""
import logging
# module level functions
def debug(msg, *args, **kwargs):
InvokeAILogger.getLogger().debug(msg, *args, **kwargs)
@ -55,49 +58,47 @@ def disable(level=logging.CRITICAL):
def basicConfig(**kwargs):
InvokeAILogger.getLogger().basicConfig(**kwargs)
def getLogger(name: str=None)->logging.Logger:
def getLogger(name: str = None) -> logging.Logger:
return InvokeAILogger.getLogger(name)
class InvokeAILogFormatter(logging.Formatter):
'''
Repurposed from:
https://stackoverflow.com/questions/14844970/modifying-logging-message-format-based-on-message-logging-level-in-python3
Custom Formatting for the InvokeAI Logger
'''
crit_fmt = "### %(msg)s"
err_fmt = "*** %(msg)s"
warn_fmt = "** %(msg)s"
info_fmt = ">> %(msg)s"
dbg_fmt = " | %(msg)s"
def __init__(self):
super().__init__(fmt="%(levelno)d: %(msg)s", datefmt=None, style='%')
# Color Codes
grey = "\x1b[38;20m"
yellow = "\x1b[33;20m"
red = "\x1b[31;20m"
cyan = "\x1b[36;20m"
bold_red = "\x1b[31;1m"
reset = "\x1b[0m"
# Log Format
format = "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
## More Formatting Options: %(pathname)s, %(filename)s, %(module)s, %(lineno)d
# Format Map
FORMATS = {
logging.DEBUG: cyan + format + reset,
logging.INFO: grey + format + reset,
logging.WARNING: yellow + format + reset,
logging.ERROR: red + format + reset,
logging.CRITICAL: bold_red + format + reset
}
def format(self, record):
# Remember the format used when the logging module
# was installed (in the event that this formatter is
# used with the vanilla logging module.
format_orig = self._style._fmt
if record.levelno == logging.DEBUG:
self._style._fmt = InvokeAILogFormatter.dbg_fmt
if record.levelno == logging.INFO:
self._style._fmt = InvokeAILogFormatter.info_fmt
if record.levelno == logging.WARNING:
self._style._fmt = InvokeAILogFormatter.warn_fmt
if record.levelno == logging.ERROR:
self._style._fmt = InvokeAILogFormatter.err_fmt
if record.levelno == logging.CRITICAL:
self._style._fmt = InvokeAILogFormatter.crit_fmt
log_fmt = self.FORMATS.get(record.levelno)
formatter = logging.Formatter(log_fmt, datefmt="%d-%m-%Y %H:%M:%S")
return formatter.format(record)
# parent class does the work
result = super().format(record)
self._style._fmt = format_orig
return result
class InvokeAILogger(object):
loggers = dict()
@classmethod
def getLogger(self, name:str='invokeai')->logging.Logger:
def getLogger(self, name: str = 'InvokeAI') -> logging.Logger:
if name not in self.loggers:
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
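The new formatter selects a colorized, timestamped format string per level rather than mutating `self._style._fmt`. A hedged sketch of wiring it up manually; `InvokeAILogger.getLogger` presumably does the equivalent internally:

```python
import logging

handler = logging.StreamHandler()
handler.setFormatter(InvokeAILogFormatter())

log = logging.getLogger("InvokeAI")
log.setLevel(logging.DEBUG)
log.addHandler(handler)

log.info("model loaded")
# -> [<dd-mm-YYYY HH:MM:SS>]::[InvokeAI]::INFO --> model loaded   (rendered in grey)
```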

View File

@ -4,17 +4,21 @@ from .parse_seed_weights import parse_seed_weights
SAMPLER_CHOICES = [
"ddim",
"k_dpm_2_a",
"k_dpm_2",
"k_dpmpp_2_a",
"k_dpmpp_2",
"k_euler_a",
"k_euler",
"k_heun",
"k_lms",
"plms",
# diffusers:
"ddpm",
"deis",
"lms",
"pndm",
"heun",
'heun_k',
"euler",
"euler_k",
"euler_a",
"kdpm_2",
"kdpm_2_a",
"dpmpp_2s",
"dpmpp_2m",
"dpmpp_2m_k",
"unipc",
]

File diff suppressed because it is too large

View File

@ -1,497 +0,0 @@
"""
Readline helper functions for invoke.py.
You may import the global singleton `completer` to get access to the
completer object itself. This is useful when you want to autocomplete
seeds:
from invokeai.frontend.CLI.readline import completer
completer.add_seed(18247566)
completer.add_seed(9281839)
"""
import atexit
import os
import re
from ...backend.args import Args
from ...backend.globals import Globals
from ...backend.stable_diffusion import HuggingFaceConceptsLibrary
# ---------------readline utilities---------------------
try:
import readline
readline_available = True
except (ImportError, ModuleNotFoundError) as e:
print(f"** An error occurred when loading the readline module: {str(e)}")
readline_available = False
IMG_EXTENSIONS = (".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG", ".gif", ".GIF")
WEIGHT_EXTENSIONS = (".ckpt", ".vae", ".safetensors")
TEXT_EXTENSIONS = (".txt", ".TXT")
CONFIG_EXTENSIONS = (".yaml", ".yml")
COMMANDS = (
"--steps",
"-s",
"--seed",
"-S",
"--iterations",
"-n",
"--width",
"-W",
"--height",
"-H",
"--cfg_scale",
"-C",
"--threshold",
"--perlin",
"--grid",
"-g",
"--individual",
"-i",
"--save_intermediates",
"--init_img",
"-I",
"--init_mask",
"-M",
"--init_color",
"--strength",
"-f",
"--variants",
"-v",
"--outdir",
"-o",
"--sampler",
"-A",
"-m",
"--embedding_path",
"--device",
"--grid",
"-g",
"--facetool",
"-ft",
"--facetool_strength",
"-G",
"--codeformer_fidelity",
"-cf",
"--upscale",
"-U",
"-save_orig",
"--save_original",
"--log_tokenization",
"-t",
"--hires_fix",
"--inpaint_replace",
"-r",
"--png_compression",
"-z",
"--text_mask",
"-tm",
"--h_symmetry_time_pct",
"--v_symmetry_time_pct",
"!fix",
"!fetch",
"!replay",
"!history",
"!search",
"!clear",
"!models",
"!switch",
"!import_model",
"!optimize_model",
"!convert_model",
"!edit_model",
"!del_model",
"!mask",
"!triggers",
)
MODEL_COMMANDS = (
"!switch",
"!edit_model",
"!del_model",
)
CKPT_MODEL_COMMANDS = ("!optimize_model",)
WEIGHT_COMMANDS = (
"!import_model",
"!convert_model",
)
IMG_PATH_COMMANDS = ("--outdir[=\s]",)
TEXT_PATH_COMMANDS = ("!replay",)
IMG_FILE_COMMANDS = (
"!fix",
"!fetch",
"!mask",
"--init_img[=\s]",
"-I",
"--init_mask[=\s]",
"-M",
"--init_color[=\s]",
"--embedding_path[=\s]",
)
path_regexp = "(" + "|".join(IMG_PATH_COMMANDS + IMG_FILE_COMMANDS) + ")\s*\S*$"
weight_regexp = "(" + "|".join(WEIGHT_COMMANDS) + ")\s*\S*$"
text_regexp = "(" + "|".join(TEXT_PATH_COMMANDS) + ")\s*\S*$"
class Completer(object):
def __init__(self, options, models={}):
self.options = sorted(options)
self.models = models
self.seeds = set()
self.matches = list()
self.default_dir = None
self.linebuffer = None
self.auto_history_active = True
self.extensions = None
self.concepts = None
self.embedding_terms = set()
return
def complete(self, text, state):
"""
Completes invoke command line.
BUG: it doesn't correctly complete files that have spaces in the name.
"""
buffer = readline.get_line_buffer()
if state == 0:
# extensions defined, so go directly into path completion mode
if self.extensions is not None:
self.matches = self._path_completions(text, state, self.extensions)
# looking for an image file
elif re.search(path_regexp, buffer):
do_shortcut = re.search("^" + "|".join(IMG_FILE_COMMANDS), buffer)
self.matches = self._path_completions(
text, state, IMG_EXTENSIONS, shortcut_ok=do_shortcut
)
# looking for a seed
elif re.search("(-S\s*|--seed[=\s])\d*$", buffer):
self.matches = self._seed_completions(text, state)
# looking for an embedding concept
elif re.search("<[\w-]*$", buffer):
self.matches = self._concept_completions(text, state)
# looking for a model
elif re.match("^" + "|".join(MODEL_COMMANDS), buffer):
self.matches = self._model_completions(text, state)
# looking for a ckpt model
elif re.match("^" + "|".join(CKPT_MODEL_COMMANDS), buffer):
self.matches = self._model_completions(text, state, ckpt_only=True)
elif re.search(weight_regexp, buffer):
self.matches = self._path_completions(
text,
state,
WEIGHT_EXTENSIONS,
default_dir=Globals.root,
)
elif re.search(text_regexp, buffer):
self.matches = self._path_completions(text, state, TEXT_EXTENSIONS)
# This is the first time for this text, so build a match list.
elif text:
self.matches = [s for s in self.options if s and s.startswith(text)]
else:
self.matches = self.options[:]
# Return the state'th item from the match list,
# if we have that many.
try:
response = self.matches[state]
except IndexError:
response = None
return response
def complete_extensions(self, extensions: list):
"""
If called with a list of extensions, will force completer
to do file path completions.
"""
self.extensions = extensions
def add_history(self, line):
"""
Pass thru to readline
"""
if not self.auto_history_active:
readline.add_history(line)
def clear_history(self):
"""
Pass clear_history() thru to readline
"""
readline.clear_history()
def search_history(self, match: str):
"""
Like show_history() but only shows items that
contain the match string.
"""
self.show_history(match)
def remove_history_item(self, pos):
readline.remove_history_item(pos)
def add_seed(self, seed):
"""
Add a seed to the autocomplete list for display when -S is autocompleted.
"""
if seed is not None:
self.seeds.add(str(seed))
def set_default_dir(self, path):
self.default_dir = path
def set_options(self, options):
self.options = options
def get_line(self, index):
try:
line = self.get_history_item(index)
except IndexError:
return None
return line
def get_current_history_length(self):
return readline.get_current_history_length()
def get_history_item(self, index):
return readline.get_history_item(index)
def show_history(self, match=None):
"""
Print the session history using the pydoc pager
"""
import pydoc
lines = list()
h_len = self.get_current_history_length()
if h_len < 1:
print("<empty history>")
return
for i in range(0, h_len):
line = self.get_history_item(i + 1)
if match and match not in line:
continue
lines.append(f"[{i+1}] {line}")
pydoc.pager("\n".join(lines))
def set_line(self, line) -> None:
"""
Set the default string displayed in the next line of input.
"""
self.linebuffer = line
readline.redisplay()
def update_models(self, models: dict) -> None:
"""
update our list of models
"""
self.models = models
def _seed_completions(self, text, state):
m = re.search("(-S\s?|--seed[=\s]?)(\d*)", text)
if m:
switch = m.groups()[0]
partial = m.groups()[1]
else:
switch = ""
partial = text
matches = list()
for s in self.seeds:
if s.startswith(partial):
matches.append(switch + s)
matches.sort()
return matches
def add_embedding_terms(self, terms: list[str]):
self.embedding_terms = set(terms)
if self.concepts:
self.embedding_terms.update(set(self.concepts.list_concepts()))
def _concept_completions(self, text, state):
if self.concepts is None:
# cache Concepts() instance so we can check for updates in concepts_list during runtime.
self.concepts = HuggingFaceConceptsLibrary()
self.embedding_terms.update(set(self.concepts.list_concepts()))
else:
self.embedding_terms.update(set(self.concepts.list_concepts()))
partial = text[1:] # this removes the leading '<'
if len(partial) == 0:
return list(self.embedding_terms) # whole dump - think if user wants this!
matches = list()
for concept in self.embedding_terms:
if concept.startswith(partial):
matches.append(f"<{concept}>")
matches.sort()
return matches
def _model_completions(self, text, state, ckpt_only=False):
m = re.search("(!switch\s+)(\w*)", text)
if m:
switch = m.groups()[0]
partial = m.groups()[1]
else:
switch = ""
partial = text
matches = list()
for s in self.models:
format = self.models[s]["format"]
if format == "vae":
continue
if ckpt_only and format != "ckpt":
continue
if s.startswith(partial):
matches.append(switch + s)
matches.sort()
return matches
def _pre_input_hook(self):
if self.linebuffer:
readline.insert_text(self.linebuffer)
readline.redisplay()
self.linebuffer = None
def _path_completions(
self, text, state, extensions, shortcut_ok=True, default_dir: str = ""
):
# separate the switch from the partial path
match = re.search("^(-\w|--\w+=?)(.*)", text)
if match is None:
switch = None
partial_path = text
else:
switch, partial_path = match.groups()
partial_path = partial_path.lstrip()
matches = list()
path = os.path.expanduser(partial_path)
if os.path.isdir(path):
dir = path
elif os.path.dirname(path) != "":
dir = os.path.dirname(path)
else:
dir = default_dir if os.path.exists(default_dir) else ""
path = os.path.join(dir, path)
dir_list = os.listdir(dir or ".")
if shortcut_ok and os.path.exists(self.default_dir) and dir == "":
dir_list += os.listdir(self.default_dir)
for node in dir_list:
if node.startswith(".") and len(node) > 1:
continue
full_path = os.path.join(dir, node)
if not (node.endswith(extensions) or os.path.isdir(full_path)):
continue
if path and not full_path.startswith(path):
continue
if switch is None:
match_path = os.path.join(dir, node)
matches.append(
match_path + "/" if os.path.isdir(full_path) else match_path
)
elif os.path.isdir(full_path):
matches.append(
switch + os.path.join(os.path.dirname(full_path), node) + "/"
)
elif node.endswith(extensions):
matches.append(switch + os.path.join(os.path.dirname(full_path), node))
return matches
class DummyCompleter(Completer):
def __init__(self, options):
super().__init__(options)
self.history = list()
def add_history(self, line):
self.history.append(line)
def clear_history(self):
self.history = list()
def get_current_history_length(self):
return len(self.history)
def get_history_item(self, index):
return self.history[index - 1]
def remove_history_item(self, index):
return self.history.pop(index - 1)
def set_line(self, line):
print(f"# {line}")
def generic_completer(commands: list) -> Completer:
if readline_available:
completer = Completer(commands, [])
readline.set_completer(completer.complete)
readline.set_pre_input_hook(completer._pre_input_hook)
readline.set_completer_delims(" ")
readline.parse_and_bind("tab: complete")
readline.parse_and_bind("set print-completions-horizontally off")
readline.parse_and_bind("set page-completions on")
readline.parse_and_bind("set skip-completed-text on")
readline.parse_and_bind("set show-all-if-ambiguous on")
else:
completer = DummyCompleter(commands)
return completer
def get_completer(opt: Args, models=[]) -> Completer:
if readline_available:
completer = Completer(COMMANDS, models)
readline.set_completer(completer.complete)
# pyreadline3 does not have a set_auto_history() method
try:
readline.set_auto_history(False)
completer.auto_history_active = False
except:
completer.auto_history_active = True
readline.set_pre_input_hook(completer._pre_input_hook)
readline.set_completer_delims(" ")
readline.parse_and_bind("tab: complete")
readline.parse_and_bind("set print-completions-horizontally off")
readline.parse_and_bind("set page-completions on")
readline.parse_and_bind("set skip-completed-text on")
readline.parse_and_bind("set show-all-if-ambiguous on")
outdir = os.path.expanduser(opt.outdir)
if os.path.isabs(outdir):
histfile = os.path.join(outdir, ".invoke_history")
else:
histfile = os.path.join(Globals.root, outdir, ".invoke_history")
try:
readline.read_history_file(histfile)
readline.set_history_length(1000)
except FileNotFoundError:
pass
except OSError: # file likely corrupted
newname = f"{histfile}.old"
print(
f"## Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
)
os.replace(histfile, newname)
atexit.register(readline.write_history_file, histfile)
else:
completer = DummyCompleter(COMMANDS)
return completer

View File

@ -1,30 +0,0 @@
'''
This is a modularized version of the sd-metadata.py script,
which retrieves and prints the metadata from a series of generated png files.
'''
import sys
import json
from invokeai.backend.image_util import retrieve_metadata
def print_metadata():
if len(sys.argv) < 2:
print("Usage: file2prompt.py <file1.png> <file2.png> <file3.png>...")
print("This script opens up the indicated invoke.py-generated PNG file(s) and prints out their metadata.")
exit(-1)
filenames = sys.argv[1:]
for f in filenames:
try:
metadata = retrieve_metadata(f)
print(f'{f}:\n',json.dumps(metadata['sd-metadata'], indent=4))
except FileNotFoundError:
sys.stderr.write(f'{f} not found\n')
continue
except PermissionError:
sys.stderr.write(f'{f} could not be opened due to inadequate permissions\n')
continue
if __name__== '__main__':
print_metadata()

View File

@ -23,7 +23,6 @@ from npyscreen import widget
from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals, global_config_dir
from ...backend.config.model_install_backend import (
Dataset_path,
@ -41,11 +40,13 @@ from .widgets import (
TextBox,
set_min_terminal_size,
)
from invokeai.app.services.config import get_invokeai_config
# minimum size for the UI
MIN_COLS = 120
MIN_LINES = 45
config = get_invokeai_config()
class addModelsForm(npyscreen.FormMultiPage):
# for responsive resizing - disabled
@ -453,9 +454,9 @@ def main():
opt = parser.parse_args()
# setting a global here
Globals.root = os.path.expanduser(get_root(opt.root) or "")
config.root = os.path.expanduser(get_root(opt.root) or "")
if not global_config_dir().exists():
if not (config.conf_path / '..' ).exists():
logger.info(
"Your InvokeAI root directory is not set up. Calling invokeai-configure."
)

View File

@ -8,7 +8,6 @@ import argparse
import curses
import os
import sys
import traceback
import warnings
from argparse import Namespace
from pathlib import Path
@ -20,20 +19,13 @@ from diffusers import logging as dlogging
from npyscreen import widget
from omegaconf import OmegaConf
from ...backend.globals import (
Globals,
global_cache_dir,
global_config_file,
global_models_dir,
global_set_root,
)
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
from ...backend.model_management import ModelManager
from ...frontend.install.widgets import FloatTitleSlider
DEST_MERGED_MODEL_DIR = "merged_models"
config = get_invokeai_config()
def merge_diffusion_models(
model_ids_or_paths: List[Union[str, Path]],
@ -60,7 +52,7 @@ def merge_diffusion_models(
pipe = DiffusionPipeline.from_pretrained(
model_ids_or_paths[0],
cache_dir=kwargs.get("cache_dir", global_cache_dir()),
cache_dir=kwargs.get("cache_dir", config.cache_dir),
custom_pipeline="checkpoint_merger",
)
merged_pipe = pipe.merge(
@ -94,7 +86,7 @@ def merge_diffusion_models_and_commit(
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
config_file = global_config_file()
config_file = config.model_conf_path
model_manager = ModelManager(OmegaConf.load(config_file))
for mod in models:
assert mod in model_manager.model_names(), f'** Unknown model "{mod}"'
@ -106,7 +98,7 @@ def merge_diffusion_models_and_commit(
merged_pipe = merge_diffusion_models(
model_ids_or_paths, alpha, interp, force, **kwargs
)
dump_path = global_models_dir() / DEST_MERGED_MODEL_DIR
dump_path = config.models_dir / DEST_MERGED_MODEL_DIR
os.makedirs(dump_path, exist_ok=True)
dump_path = dump_path / merged_model_name
@ -126,7 +118,7 @@ def _parse_args() -> Namespace:
parser.add_argument(
"--root_dir",
type=Path,
default=Globals.root,
default=config.root,
help="Path to the invokeai runtime directory",
)
parser.add_argument(
@ -398,7 +390,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
class Mergeapp(npyscreen.NPSAppManaged):
def __init__(self):
super().__init__()
conf = OmegaConf.load(global_config_file())
conf = OmegaConf.load(config.model_conf_path)
self.model_manager = ModelManager(
conf, "cpu", "float16"
) # precision doesn't really matter here
@ -429,7 +421,7 @@ def run_cli(args: Namespace):
f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
)
model_manager = ModelManager(OmegaConf.load(global_config_file()))
model_manager = ModelManager(OmegaConf.load(config.model_conf_path))
assert (
args.clobber or args.merged_model_name not in model_manager.model_names()
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
@ -440,9 +432,9 @@ def run_cli(args: Namespace):
def main():
args = _parse_args()
global_set_root(args.root_dir)
config.root = args.root_dir
cache_dir = str(global_cache_dir("hub"))
cache_dir = config.cache_dir
os.environ[
"HF_HOME"
] = cache_dir # because not clear the merge pipeline is honoring cache_dir
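Putting the config-driven pieces of this file together, a hedged sketch of a programmatic merge; the repo ids, `alpha`, and `interp` values are illustrative, and the parameter names follow the call shown above:

```python
import os
from invokeai.app.services.config import get_invokeai_config

config = get_invokeai_config()
os.environ["HF_HOME"] = str(config.cache_dir)  # mirror main(): keep HF downloads under the InvokeAI cache

merged_pipe = merge_diffusion_models(
    ["runwayml/stable-diffusion-v1-5", "my-finetune"],  # hypothetical repo ids or local paths
    alpha=0.5,         # blend weight between the two models
    interp="sigmoid",  # one of the checkpoint_merger interpolation modes
    force=False,
    cache_dir=config.cache_dir,
)
merged_pipe.save_pretrained(config.models_dir / "merged_models" / "my-merge")
```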

View File

@ -21,14 +21,17 @@ from npyscreen import widget
from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals, global_set_root
from ...backend.training import do_textual_inversion_training, parse_args
from invokeai.app.services.config import get_invokeai_config
from ...backend.training import (
do_textual_inversion_training,
parse_args
)
TRAINING_DATA = "text-inversion-training-data"
TRAINING_DIR = "text-inversion-output"
CONF_FILE = "preferences.conf"
config = None
class textualInversionForm(npyscreen.FormMultiPageAction):
resolutions = [512, 768, 1024]
@ -122,7 +125,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
value=str(
saved_args.get(
"train_data_dir",
Path(Globals.root) / TRAINING_DATA / default_placeholder_token,
config.root_dir / TRAINING_DATA / default_placeholder_token,
)
),
scroll_exit=True,
@ -135,7 +138,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
value=str(
saved_args.get(
"output_dir",
Path(Globals.root) / TRAINING_DIR / default_placeholder_token,
config.root_dir / TRAINING_DIR / default_placeholder_token,
)
),
scroll_exit=True,
@ -241,9 +244,9 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
placeholder = self.placeholder_token.value
self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)"
self.train_data_dir.value = str(
Path(Globals.root) / TRAINING_DATA / placeholder
config.root_dir / TRAINING_DATA / placeholder
)
self.output_dir.value = str(Path(Globals.root) / TRAINING_DIR / placeholder)
self.output_dir.value = str(config.root_dir / TRAINING_DIR / placeholder)
self.resume_from_checkpoint.value = Path(self.output_dir.value).exists()
def on_ok(self):
@ -284,7 +287,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
return True
def get_model_names(self) -> Tuple[List[str], int]:
conf = OmegaConf.load(os.path.join(Globals.root, "configs/models.yaml"))
conf = OmegaConf.load(config.root_dir / "configs/models.yaml")
model_names = [
idx
for idx in sorted(list(conf.keys()))
@ -367,7 +370,7 @@ def copy_to_embeddings_folder(args: dict):
"""
source = Path(args["output_dir"], "learned_embeds.bin")
dest_dir_name = args["placeholder_token"].strip("<>")
destination = Path(Globals.root, "embeddings", dest_dir_name)
destination = config.root_dir / "embeddings" / dest_dir_name
os.makedirs(destination, exist_ok=True)
logger.info(f"Training completed. Copying learned_embeds.bin into {str(destination)}")
shutil.copy(source, destination)
@ -383,7 +386,7 @@ def save_args(args: dict):
"""
Save the current argument values to an omegaconf file
"""
dest_dir = Path(Globals.root) / TRAINING_DIR
dest_dir = config.root_dir / TRAINING_DIR
os.makedirs(dest_dir, exist_ok=True)
conf_file = dest_dir / CONF_FILE
conf = OmegaConf.create(args)
@ -394,7 +397,7 @@ def previous_args() -> dict:
"""
Get the previous arguments used.
"""
conf_file = Path(Globals.root) / TRAINING_DIR / CONF_FILE
conf_file = config.root_dir / TRAINING_DIR / CONF_FILE
try:
conf = OmegaConf.load(conf_file)
conf["placeholder_token"] = conf["placeholder_token"].strip("<>")
@ -420,7 +423,7 @@ def do_front_end(args: Namespace):
save_args(args)
try:
do_textual_inversion_training(**args)
do_textual_inversion_training(get_invokeai_config(),**args)
copy_to_embeddings_folder(args)
except Exception as e:
logger.error("An exception occurred during training. The exception was:")
@ -430,13 +433,20 @@ def do_front_end(args: Namespace):
def main():
global config
args = parse_args()
global_set_root(args.root_dir or Globals.root)
config = get_invokeai_config(argv=[])
# change root if needed
if args.root_dir:
config.root = args.root_dir
try:
if args.front_end:
do_front_end(args)
else:
do_textual_inversion_training(**vars(args))
do_textual_inversion_training(config,**vars(args))
except AssertionError as e:
logger.error(e)
sys.exit(-1)

View File

@ -1,13 +0,0 @@
{
"plugins": [
[
"transform-imports",
{
"lodash": {
"transform": "lodash/${member}",
"preventFullImport": true
}
}
]
]
}

View File

@ -35,3 +35,7 @@ stats.html
!.yarn/releases
!.yarn/sdks
!.yarn/versions
# Yalc
.yalc
yalc.lock

View File

@ -5,6 +5,7 @@ import { PluginOption, UserConfig } from 'vite';
import dts from 'vite-plugin-dts';
import eslint from 'vite-plugin-eslint';
import tsconfigPaths from 'vite-tsconfig-paths';
import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js';
export const packageConfig: UserConfig = {
base: './',
@ -16,9 +17,10 @@ export const packageConfig: UserConfig = {
dts({
insertTypesEntry: true,
}),
cssInjectedByJsPlugin(),
],
build: {
chunkSizeWarningLimit: 1500,
cssCodeSplit: true,
lib: {
entry: path.resolve(__dirname, '../src/index.ts'),
name: 'InvokeAIUI',
@ -30,6 +32,7 @@ export const packageConfig: UserConfig = {
globals: {
react: 'React',
'react-dom': 'ReactDOM',
'@emotion/react': 'EmotionReact',
},
},
},

View File

@ -15,15 +15,3 @@ The `postinstall` script patches a few packages and runs the Chakra CLI to gener
### Patch `@chakra-ui/cli`
See: <https://github.com/chakra-ui/chakra-ui/issues/7394>
### Patch `redux-persist`
We want to persist the canvas state to `localStorage` but many canvas operations change data very quickly, so we need to debounce the writes to `localStorage`.
`redux-persist` is unfortunately unmaintained. The repo's current code is nonfunctional, but the last release's code depends on a package that was removed from `npm` for being malware, so we cannot just fork it.
So, we have to patch it directly. Perhaps a better way would be to write a debounced storage adapter, but I couldn't figure out how to do that.
### Patch `redux-deep-persist`
This package makes blacklisting and whitelisting persist configs very simple, but we have to patch it to match `redux-persist` for the types to work.

View File

@ -37,7 +37,7 @@ From `invokeai/frontend/web/` run `yarn install` to get everything set up.
Start everything in dev mode:
1. Start the dev server: `yarn dev`
2. Start the InvokeAI UI per usual: `invokeai --web`
2. Start the InvokeAI Nodes backend: `python scripts/invokeai-new.py --web # run from the repo root`
3. Point your browser to the dev server address e.g. <http://localhost:5173/>
### Production builds

View File

@ -21,7 +21,6 @@
"scripts": {
"prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
"dev:nodes": "concurrently \"vite dev --mode nodes\" \"yarn run theme:watch\"",
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
"build": "yarn run lint && vite build",
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
@ -63,11 +62,13 @@
"@dagrejs/graphlib": "^2.1.12",
"@emotion/react": "^11.10.6",
"@emotion/styled": "^11.10.6",
"@floating-ui/react-dom": "^2.0.0",
"@fontsource/inter": "^4.5.15",
"@reduxjs/toolkit": "^1.9.5",
"@roarr/browser-log-writer": "^1.1.5",
"chakra-ui-contextmenu": "^1.0.5",
"dateformat": "^5.0.3",
"downshift": "^7.6.0",
"formik": "^2.2.9",
"framer-motion": "^10.12.4",
"fuse.js": "^6.6.2",
@ -88,17 +89,14 @@
"react-i18next": "^12.2.2",
"react-icons": "^4.7.1",
"react-konva": "^18.2.7",
"react-konva-utils": "^1.0.4",
"react-redux": "^8.0.5",
"react-rnd": "^10.4.1",
"react-transition-group": "^4.4.5",
"react-resizable-panels": "^0.0.42",
"react-use": "^17.4.0",
"react-virtuoso": "^4.3.5",
"react-zoom-pan-pinch": "^3.0.7",
"reactflow": "^11.7.0",
"redux-deep-persist": "^1.0.7",
"redux-dynamic-middlewares": "^2.2.0",
"redux-persist": "^6.0.0",
"redux-remember": "^3.3.1",
"roarr": "^7.15.0",
"serialize-error": "^11.0.0",
"socket.io-client": "^4.6.0",
@ -118,6 +116,7 @@
"@types/node": "^18.16.2",
"@types/react": "^18.2.0",
"@types/react-dom": "^18.2.1",
"@types/react-redux": "^7.1.25",
"@types/react-transition-group": "^4.4.5",
"@types/uuid": "^9.0.0",
"@typescript-eslint/eslint-plugin": "^5.59.1",
@ -143,6 +142,7 @@
"terser": "^5.17.1",
"ts-toolbelt": "^9.6.0",
"vite": "^4.3.3",
"vite-plugin-css-injected-by-js": "^3.1.1",
"vite-plugin-dts": "^2.3.0",
"vite-plugin-eslint": "^1.8.1",
"vite-tsconfig-paths": "^4.2.0",

View File

@ -1,24 +0,0 @@
diff --git a/node_modules/redux-deep-persist/lib/types.d.ts b/node_modules/redux-deep-persist/lib/types.d.ts
index b67b8c2..7fc0fa1 100644
--- a/node_modules/redux-deep-persist/lib/types.d.ts
+++ b/node_modules/redux-deep-persist/lib/types.d.ts
@@ -35,6 +35,7 @@ export interface PersistConfig<S, RS = any, HSS = any, ESS = any> {
whitelist?: Array<string>;
transforms?: Array<Transform<HSS, ESS, S, RS>>;
throttle?: number;
+ debounce?: number;
migrate?: PersistMigrate;
stateReconciler?: false | StateReconciler<S>;
getStoredState?: (config: PersistConfig<S, RS, HSS, ESS>) => Promise<PersistedState>;
diff --git a/node_modules/redux-deep-persist/src/types.ts b/node_modules/redux-deep-persist/src/types.ts
index 398ac19..cbc5663 100644
--- a/node_modules/redux-deep-persist/src/types.ts
+++ b/node_modules/redux-deep-persist/src/types.ts
@@ -91,6 +91,7 @@ export interface PersistConfig<S, RS = any, HSS = any, ESS = any> {
whitelist?: Array<string>;
transforms?: Array<Transform<HSS, ESS, S, RS>>;
throttle?: number;
+ debounce?: number;
migrate?: PersistMigrate;
stateReconciler?: false | StateReconciler<S>;
/**

View File

@ -1,116 +0,0 @@
diff --git a/node_modules/redux-persist/es/createPersistoid.js b/node_modules/redux-persist/es/createPersistoid.js
index 8b43b9a..184faab 100644
--- a/node_modules/redux-persist/es/createPersistoid.js
+++ b/node_modules/redux-persist/es/createPersistoid.js
@@ -6,6 +6,7 @@ export default function createPersistoid(config) {
var whitelist = config.whitelist || null;
var transforms = config.transforms || [];
var throttle = config.throttle || 0;
+ var debounce = config.debounce || 0;
var storageKey = "".concat(config.keyPrefix !== undefined ? config.keyPrefix : KEY_PREFIX).concat(config.key);
var storage = config.storage;
var serialize;
@@ -28,30 +29,37 @@ export default function createPersistoid(config) {
var timeIterator = null;
var writePromise = null;
- var update = function update(state) {
- // add any changed keys to the queue
- Object.keys(state).forEach(function (key) {
- if (!passWhitelistBlacklist(key)) return; // is keyspace ignored? noop
+ // Timer for debounced `update()`
+ let timer = 0;
- if (lastState[key] === state[key]) return; // value unchanged? noop
+ function update(state) {
+ // Debounce the update
+ clearTimeout(timer);
+ timer = setTimeout(() => {
+ // add any changed keys to the queue
+ Object.keys(state).forEach(function (key) {
+ if (!passWhitelistBlacklist(key)) return; // is keyspace ignored? noop
- if (keysToProcess.indexOf(key) !== -1) return; // is key already queued? noop
+ if (lastState[key] === state[key]) return; // value unchanged? noop
- keysToProcess.push(key); // add key to queue
- }); //if any key is missing in the new state which was present in the lastState,
- //add it for processing too
+ if (keysToProcess.indexOf(key) !== -1) return; // is key already queued? noop
- Object.keys(lastState).forEach(function (key) {
- if (state[key] === undefined && passWhitelistBlacklist(key) && keysToProcess.indexOf(key) === -1 && lastState[key] !== undefined) {
- keysToProcess.push(key);
- }
- }); // start the time iterator if not running (read: throttle)
+ keysToProcess.push(key); // add key to queue
+ }); //if any key is missing in the new state which was present in the lastState,
+ //add it for processing too
- if (timeIterator === null) {
- timeIterator = setInterval(processNextKey, throttle);
- }
+ Object.keys(lastState).forEach(function (key) {
+ if (state[key] === undefined && passWhitelistBlacklist(key) && keysToProcess.indexOf(key) === -1 && lastState[key] !== undefined) {
+ keysToProcess.push(key);
+ }
+ }); // start the time iterator if not running (read: throttle)
+
+ if (timeIterator === null) {
+ timeIterator = setInterval(processNextKey, throttle);
+ }
- lastState = state;
+ lastState = state;
+ }, debounce)
};
function processNextKey() {
diff --git a/node_modules/redux-persist/es/types.js.flow b/node_modules/redux-persist/es/types.js.flow
index c50d3cd..39d8be2 100644
--- a/node_modules/redux-persist/es/types.js.flow
+++ b/node_modules/redux-persist/es/types.js.flow
@@ -19,6 +19,7 @@ export type PersistConfig = {
whitelist?: Array<string>,
transforms?: Array<Transform>,
throttle?: number,
+ debounce?: number,
migrate?: (PersistedState, number) => Promise<PersistedState>,
stateReconciler?: false | Function,
getStoredState?: PersistConfig => Promise<PersistedState>, // used for migrations
diff --git a/node_modules/redux-persist/lib/types.js.flow b/node_modules/redux-persist/lib/types.js.flow
index c50d3cd..39d8be2 100644
--- a/node_modules/redux-persist/lib/types.js.flow
+++ b/node_modules/redux-persist/lib/types.js.flow
@@ -19,6 +19,7 @@ export type PersistConfig = {
whitelist?: Array<string>,
transforms?: Array<Transform>,
throttle?: number,
+ debounce?: number,
migrate?: (PersistedState, number) => Promise<PersistedState>,
stateReconciler?: false | Function,
getStoredState?: PersistConfig => Promise<PersistedState>, // used for migrations
diff --git a/node_modules/redux-persist/src/types.js b/node_modules/redux-persist/src/types.js
index c50d3cd..39d8be2 100644
--- a/node_modules/redux-persist/src/types.js
+++ b/node_modules/redux-persist/src/types.js
@@ -19,6 +19,7 @@ export type PersistConfig = {
whitelist?: Array<string>,
transforms?: Array<Transform>,
throttle?: number,
+ debounce?: number,
migrate?: (PersistedState, number) => Promise<PersistedState>,
stateReconciler?: false | Function,
getStoredState?: PersistConfig => Promise<PersistedState>, // used for migrations
diff --git a/node_modules/redux-persist/types/types.d.ts b/node_modules/redux-persist/types/types.d.ts
index b3733bc..2a1696c 100644
--- a/node_modules/redux-persist/types/types.d.ts
+++ b/node_modules/redux-persist/types/types.d.ts
@@ -35,6 +35,7 @@ declare module "redux-persist/es/types" {
whitelist?: Array<string>;
transforms?: Array<Transform<HSS, ESS, S, RS>>;
throttle?: number;
+ debounce?: number;
migrate?: PersistMigrate;
stateReconciler?: false | StateReconciler<S>;
/**

View File

@ -25,7 +25,7 @@
"common": {
"hotkeysLabel": "Hotkeys",
"themeLabel": "Theme",
"languagePickerLabel": "Language Picker",
"languagePickerLabel": "Language",
"reportBugLabel": "Report Bug",
"githubLabel": "Github",
"discordLabel": "Discord",
@ -54,7 +54,7 @@
"img2img": "Image To Image",
"unifiedCanvas": "Unified Canvas",
"linear": "Linear",
"nodes": "Nodes",
"nodes": "Node Editor",
"postprocessing": "Post Processing",
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
"postProcessing": "Post Processing",
@ -102,7 +102,8 @@
"generate": "Generate",
"openInNewTab": "Open in New Tab",
"dontAskMeAgain": "Don't ask me again",
"areYouSure": "Are you sure?"
"areYouSure": "Are you sure?",
"imagePrompt": "Image Prompt"
},
"gallery": {
"generations": "Generations",
@ -449,13 +450,14 @@
"cfgScale": "CFG Scale",
"width": "Width",
"height": "Height",
"sampler": "Sampler",
"scheduler": "Scheduler",
"seed": "Seed",
"imageToImage": "Image to Image",
"randomizeSeed": "Randomize Seed",
"shuffle": "Shuffle",
"shuffle": "Shuffle Seed",
"noiseThreshold": "Noise Threshold",
"perlinNoise": "Perlin Noise",
"noiseSettings": "Noise",
"variations": "Variations",
"variationAmount": "Variation Amount",
"seedWeights": "Seed Weights",
@ -470,6 +472,8 @@
"scale": "Scale",
"otherOptions": "Other Options",
"seamlessTiling": "Seamless Tiling",
"seamlessXAxis": "X Axis",
"seamlessYAxis": "Y Axis",
"hiresOptim": "High Res Optimization",
"hiresStrength": "High Res Strength",
"imageFit": "Fit Initial Image To Output Size",
@ -527,7 +531,8 @@
"useCanvasBeta": "Use Canvas Beta Layout",
"enableImageDebugging": "Enable Image Debugging",
"useSlidersForAll": "Use Sliders For All Options",
"autoShowProgress": "Auto Show Progress Images",
"showProgressInViewer": "Show Progress Images in Viewer",
"antialiasProgressImages": "Antialias Progress Images",
"resetWebUI": "Reset Web UI",
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
@ -535,7 +540,10 @@
"consoleLogLevel": "Log Level",
"shouldLogToConsole": "Console Logging",
"developer": "Developer",
"general": "General"
"general": "General",
"generation": "Generation",
"ui": "User Interface",
"availableSchedulers": "Available Schedulers"
},
"toast": {
"serverError": "Server Error",
@ -544,13 +552,14 @@
"canceled": "Processing Canceled",
"tempFoldersEmptied": "Temp Folder Emptied",
"uploadFailed": "Upload failed",
"uploadFailedMultipleImagesDesc": "Multiple images pasted, may only upload one image at a time",
"uploadFailedUnableToLoadDesc": "Unable to load file",
"uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image",
"downloadImageStarted": "Image Download Started",
"imageCopied": "Image Copied",
"imageLinkCopied": "Image Link Copied",
"problemCopyingImageLink": "Unable to Copy Image Link",
"imageNotLoaded": "No Image Loaded",
"imageNotLoadedDesc": "No image found to send to image to image module",
"imageNotLoadedDesc": "Could not find image",
"imageSavedToGallery": "Image Saved to Gallery",
"canvasMerged": "Canvas Merged",
"sentToImageToImage": "Sent To Image To Image",
@ -645,7 +654,8 @@
"betaClear": "Clear",
"betaDarkenOutside": "Darken Outside",
"betaLimitToBox": "Limit To Box",
"betaPreserveMasked": "Preserve Masked"
"betaPreserveMasked": "Preserve Masked",
"antialiasing": "Antialiasing"
},
"ui": {
"showProgressImages": "Show Progress Images",

View File

@ -1,46 +1,44 @@
import ImageUploader from 'common/components/ImageUploader';
import ProgressBar from 'features/system/components/ProgressBar';
import SiteHeader from 'features/system/components/SiteHeader';
import ProgressBar from 'features/system/components/ProgressBar';
import InvokeTabs from 'features/ui/components/InvokeTabs';
import useToastWatcher from 'features/system/hooks/useToastWatcher';
import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton';
import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons';
import { Box, Flex, Grid, Portal, useColorMode } from '@chakra-ui/react';
import { Box, Flex, Grid, Portal } from '@chakra-ui/react';
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
import GalleryDrawer from 'features/gallery/components/GalleryPanel';
import Lightbox from 'features/lightbox/components/Lightbox';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
memo,
PropsWithChildren,
useCallback,
useEffect,
useState,
} from 'react';
import { memo, ReactNode, useCallback, useEffect, useState } from 'react';
import { motion, AnimatePresence } from 'framer-motion';
import Loading from 'common/components/Loading/Loading';
import { useIsApplicationReady } from 'features/system/hooks/useIsApplicationReady';
import { PartialAppConfig } from 'app/types/invokeai';
import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
import { configChanged } from 'features/system/store/configSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useLogger } from 'app/logging/useLogger';
import ProgressImagePreview from 'features/parameters/components/ProgressImagePreview';
import ParametersDrawer from 'features/ui/components/ParametersDrawer';
import { languageSelector } from 'features/system/store/systemSelectors';
import i18n from 'i18n';
import Toaster from './Toaster';
import GlobalHotkeys from './GlobalHotkeys';
const DEFAULT_CONFIG = {};
interface Props extends PropsWithChildren {
interface Props {
config?: PartialAppConfig;
headerComponent?: ReactNode;
setIsReady?: (isReady: boolean) => void;
}
const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
useToastWatcher();
useGlobalHotkeys();
const log = useLogger();
const App = ({
config = DEFAULT_CONFIG,
headerComponent,
setIsReady,
}: Props) => {
const language = useAppSelector(languageSelector);
const currentTheme = useAppSelector((state) => state.ui.currentTheme);
const log = useLogger();
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
@ -48,23 +46,33 @@ const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
const [loadingOverridden, setLoadingOverridden] = useState(false);
const { setColorMode } = useColorMode();
const dispatch = useAppDispatch();
useEffect(() => {
i18n.changeLanguage(language);
}, [language]);
useEffect(() => {
log.info({ namespace: 'App', data: config }, 'Received config');
dispatch(configChanged(config));
}, [dispatch, config, log]);
useEffect(() => {
setColorMode(['light'].includes(currentTheme) ? 'light' : 'dark');
}, [setColorMode, currentTheme]);
const handleOverrideClicked = useCallback(() => {
setLoadingOverridden(true);
}, []);
useEffect(() => {
if (isApplicationReady && setIsReady) {
setIsReady(true);
}
return () => {
setIsReady && setIsReady(false);
};
}, [isApplicationReady, setIsReady]);
return (
<>
<Grid w="100vw" h="100vh" position="relative" overflow="hidden">
{isLightboxEnabled && <Lightbox />}
<ImageUploader>
@ -76,7 +84,7 @@ const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
w={APP_WIDTH}
h={APP_HEIGHT}
>
{children || <SiteHeader />}
{headerComponent || <SiteHeader />}
<Flex
gap={4}
w={{ base: '100vw', xl: 'full' }}
@ -84,11 +92,13 @@ const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
flexDir={{ base: 'column', xl: 'row' }}
>
<InvokeTabs />
<ImageGalleryPanel />
</Flex>
</Grid>
</ImageUploader>
<GalleryDrawer />
<ParametersDrawer />
<AnimatePresence>
{!isApplicationReady && !loadingOverridden && (
<motion.div
@ -121,8 +131,10 @@ const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
<Portal>
<FloatingGalleryButton />
</Portal>
<ProgressImagePreview />
</Grid>
<Toaster />
<GlobalHotkeys />
</>
);
};

View File

@ -0,0 +1,44 @@
import { Flex, Spinner, Tooltip } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { systemSelector } from 'features/system/store/systemSelectors';
import { memo } from 'react';
const selector = createSelector(systemSelector, (system) => {
const { isUploading } = system;
let tooltip = '';
if (isUploading) {
tooltip = 'Uploading...';
}
return {
tooltip,
shouldShow: isUploading,
};
});
export const AuxiliaryProgressIndicator = () => {
const { shouldShow, tooltip } = useAppSelector(selector);
if (!shouldShow) {
return null;
}
return (
<Flex
sx={{
alignItems: 'center',
justifyContent: 'center',
color: 'base.600',
}}
>
<Tooltip label={tooltip} placement="right" hasArrow>
<Spinner />
</Tooltip>
</Flex>
);
};
export default memo(AuxiliaryProgressIndicator);

View File

@ -2,7 +2,15 @@ import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import {
setActiveTab,
toggleGalleryPanel,
toggleParametersPanel,
togglePinGalleryPanel,
togglePinParametersPanel,
} from 'features/ui/store/uiSlice';
import { isEqual } from 'lodash-es';
import React, { memo } from 'react';
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
const globalHotkeysSelector = createSelector(
@ -20,7 +28,11 @@ const globalHotkeysSelector = createSelector(
// TODO: Does not catch keypresses while focused in an input. Maybe there is a way?
export const useGlobalHotkeys = () => {
/**
* Logical component. Handles app-level global hotkeys.
* @returns null
*/
const GlobalHotkeys: React.FC = () => {
const dispatch = useAppDispatch();
const { shift } = useAppSelector(globalHotkeysSelector);
@ -36,4 +48,40 @@ export const useGlobalHotkeys = () => {
{ keyup: true, keydown: true },
[shift]
);
useHotkeys('o', () => {
dispatch(toggleParametersPanel());
});
useHotkeys(['shift+o'], () => {
dispatch(togglePinParametersPanel());
});
useHotkeys('g', () => {
dispatch(toggleGalleryPanel());
});
useHotkeys(['shift+g'], () => {
dispatch(togglePinGalleryPanel());
});
useHotkeys('1', () => {
dispatch(setActiveTab('txt2img'));
});
useHotkeys('2', () => {
dispatch(setActiveTab('img2img'));
});
useHotkeys('3', () => {
dispatch(setActiveTab('unifiedCanvas'));
});
useHotkeys('4', () => {
dispatch(setActiveTab('nodes'));
});
return null;
};
export default memo(GlobalHotkeys);

View File

@ -1,18 +1,13 @@
import React, { lazy, memo, PropsWithChildren, useEffect } from 'react';
import React, {
lazy,
memo,
PropsWithChildren,
ReactNode,
useEffect,
} from 'react';
import { Provider } from 'react-redux';
import { PersistGate } from 'redux-persist/integration/react';
import { store } from 'app/store/store';
import { persistor } from '../store/persistor';
import { OpenAPI } from 'services/api';
import '@fontsource/inter/100.css';
import '@fontsource/inter/200.css';
import '@fontsource/inter/300.css';
import '@fontsource/inter/400.css';
import '@fontsource/inter/500.css';
import '@fontsource/inter/600.css';
import '@fontsource/inter/700.css';
import '@fontsource/inter/800.css';
import '@fontsource/inter/900.css';
import Loading from '../../common/components/Loading/Loading';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
@ -28,9 +23,17 @@ interface Props extends PropsWithChildren {
apiUrl?: string;
token?: string;
config?: PartialAppConfig;
headerComponent?: ReactNode;
setIsReady?: (isReady: boolean) => void;
}
const InvokeAIUI = ({ apiUrl, token, config, children }: Props) => {
const InvokeAIUI = ({
apiUrl,
token,
config,
headerComponent,
setIsReady,
}: Props) => {
useEffect(() => {
// configure API client token
if (token) {
@ -57,13 +60,15 @@ const InvokeAIUI = ({ apiUrl, token, config, children }: Props) => {
return (
<React.StrictMode>
<Provider store={store}>
<PersistGate loading={<Loading />} persistor={persistor}>
<React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider>
<App config={config}>{children}</App>
<App
config={config}
headerComponent={headerComponent}
setIsReady={setIsReady}
/>
</ThemeLocaleProvider>
</React.Suspense>
</PersistGate>
</Provider>
</React.StrictMode>
);

View File

@ -1,4 +1,8 @@
import { ChakraProvider, extendTheme } from '@chakra-ui/react';
import {
ChakraProvider,
createLocalStorageManager,
extendTheme,
} from '@chakra-ui/react';
import { ReactNode, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { theme as invokeAITheme } from 'theme/theme';
@ -9,15 +13,8 @@ import { greenTeaThemeColors } from 'theme/colors/greenTea';
import { invokeAIThemeColors } from 'theme/colors/invokeAI';
import { lightThemeColors } from 'theme/colors/lightTheme';
import { oceanBlueColors } from 'theme/colors/oceanBlue';
import '@fontsource/inter/100.css';
import '@fontsource/inter/200.css';
import '@fontsource/inter/300.css';
import '@fontsource/inter/400.css';
import '@fontsource/inter/500.css';
import '@fontsource/inter/600.css';
import '@fontsource/inter/700.css';
import '@fontsource/inter/800.css';
import '@fontsource/inter/900.css';
import '@fontsource/inter/variable.css';
import 'overlayscrollbars/overlayscrollbars.css';
import 'theme/css/overlayscrollbars.css';
@@ -32,6 +29,8 @@ const THEMES = {
ocean: oceanBlueColors,
};
const manager = createLocalStorageManager('@@invokeai-color-mode');
function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
const { i18n } = useTranslation();
@@ -51,7 +50,11 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
document.body.dir = direction;
}, [direction]);
return <ChakraProvider theme={theme}>{children}</ChakraProvider>;
return (
<ChakraProvider theme={theme} colorModeManager={manager}>
{children}
</ChakraProvider>
);
}
export default ThemeLocaleProvider;

View File

@@ -0,0 +1,65 @@
import { useToast, UseToastOptions } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { toastQueueSelector } from 'features/system/store/systemSelectors';
import { addToast, clearToastQueue } from 'features/system/store/systemSlice';
import { useCallback, useEffect } from 'react';
export type MakeToastArg = string | UseToastOptions;
/**
* Makes a toast from a string or a UseToastOptions object.
* If a string is passed, the toast will have the status 'info' and will be closable with a duration of 2500ms.
*/
export const makeToast = (arg: MakeToastArg): UseToastOptions => {
if (typeof arg === 'string') {
return {
title: arg,
status: 'info',
isClosable: true,
duration: 2500,
};
}
return { status: 'info', isClosable: true, duration: 2500, ...arg };
};
/**
* Logical component. Watches the toast queue and makes toasts when the queue is not empty.
* @returns null
*/
const Toaster = () => {
const dispatch = useAppDispatch();
const toastQueue = useAppSelector(toastQueueSelector);
const toast = useToast();
useEffect(() => {
toastQueue.forEach((t) => {
toast(t);
});
toastQueue.length > 0 && dispatch(clearToastQueue());
}, [dispatch, toast, toastQueue]);
return null;
};
/**
* Returns a function that can be used to make a toast.
* @example
* const toaster = useAppToaster();
* toaster('Hello world!');
* toaster({ title: 'Hello world!', status: 'success' });
* @returns A function that can be used to make a toast.
* @see makeToast
* @see MakeToastArg
* @see UseToastOptions
*/
export const useAppToaster = () => {
const dispatch = useAppDispatch();
const toaster = useCallback(
(arg: MakeToastArg) => dispatch(addToast(makeToast(arg))),
[dispatch]
);
return toaster;
};
export default Toaster;
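
A small usage sketch of the hook above: `useAppToaster` returns a callback that wraps its argument with `makeToast` and dispatches `addToast`, so the queued toast is rendered by the mounted `<Toaster />`. The component and import path here are illustrative:

```
// Illustrative only: queue a toast from any component; the makeToast defaults
// ('info', closable, 2500ms) apply unless overridden as below.
import { useAppToaster } from 'app/components/Toaster'; // path assumed

const SaveSettingsButton = () => {
  const toaster = useAppToaster();

  return (
    <button
      onClick={() => toaster({ title: 'Settings saved', status: 'success' })}
    >
      Save
    </button>
  );
};

export default SaveSettingsButton;
```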

View File

@@ -1,17 +1,28 @@
// TODO: use Enums?
export const DIFFUSERS_SCHEDULERS: Array<string> = [
export const SCHEDULERS = [
'ddim',
'plms',
'k_lms',
'dpmpp_2',
'k_dpm_2',
'k_dpm_2_a',
'k_dpmpp_2',
'k_euler',
'k_euler_a',
'k_heun',
];
'lms',
'euler',
'euler_k',
'euler_a',
'dpmpp_2s',
'dpmpp_2m',
'dpmpp_2m_k',
'kdpm_2',
'kdpm_2_a',
'deis',
'ddpm',
'pndm',
'heun',
'heun_k',
'unipc',
] as const;
export type Scheduler = (typeof SCHEDULERS)[number];
export const isScheduler = (x: string): x is Scheduler =>
SCHEDULERS.includes(x as Scheduler);
// Valid image widths
export const WIDTHS: Array<number> = Array.from(Array(64)).map(
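
Because `SCHEDULERS` is now a readonly tuple, the `Scheduler` union is derived from it and `isScheduler` can narrow arbitrary strings. A short sketch of the guard in use, for example validating a value restored from persisted state; the import path and fallback value are illustrative:

```
import { isScheduler, Scheduler } from 'app/constants'; // path assumed

// Illustrative only: narrow an untrusted string to the Scheduler union,
// falling back to 'euler' when the stored value is stale or unknown.
const parseScheduler = (value: string | null): Scheduler =>
  value && isScheduler(value) ? value : 'euler';

const scheduler = parseScheduler(localStorage.getItem('scheduler'));
// `scheduler` is typed as Scheduler, not string
```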

View File

@@ -1,26 +1,20 @@
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { validateSeedWeights } from 'common/util/seedWeightPairs';
import { initialCanvasImageSelector } from 'features/canvas/store/canvasSelectors';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { systemSelector } from 'features/system/store/systemSelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash-es';
export const readinessSelector = createSelector(
[
generationSelector,
systemSelector,
initialCanvasImageSelector,
activeTabNameSelector,
],
(generation, system, initialCanvasImage, activeTabName) => {
[generationSelector, systemSelector, activeTabNameSelector],
(generation, system, activeTabName) => {
const {
prompt,
shouldGenerateVariations,
seedWeights,
initialImage,
seed,
isImageToImageEnabled,
} = generation;
const { isProcessing, isConnected } = system;
@@ -34,7 +28,7 @@ export const readinessSelector = createSelector(
reasonsWhyNotReady.push('Missing prompt');
}
if (isImageToImageEnabled && !initialImage) {
if (activeTabName === 'img2img' && !initialImage) {
isReady = false;
reasonsWhyNotReady.push('No initial image selected');
}
@ -64,10 +58,5 @@ export const readinessSelector = createSelector(
// All good
return { isReady, reasonsWhyNotReady };
},
{
memoizeOptions: {
equalityCheck: isEqual,
resultEqualityCheck: isEqual,
},
}
defaultSelectorOptions
);
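
The inline `memoizeOptions` block is replaced by a shared `defaultSelectorOptions` object from `app/store/util/defaultMemoizeOptions`. That module is not shown in this diff; a plausible sketch, inferred from the options it replaces here (the real contents may differ):

```
// Hedged sketch of the shared selector options, inferred from the inline
// lodash isEqual checks this hunk removes; not the verbatim module contents.
import { isEqual } from 'lodash-es';

export const defaultSelectorOptions = {
  memoizeOptions: {
    equalityCheck: isEqual,
    resultEqualityCheck: isEqual,
  },
};
```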

View File

@@ -1,209 +1,209 @@
// import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
// import * as InvokeAI from 'app/types/invokeai';
// import type { RootState } from 'app/store/store';
// import {
// frontendToBackendParameters,
// FrontendToBackendParametersConfig,
// } from 'common/util/parameterTranslation';
// import dateFormat from 'dateformat';
// import {
// GalleryCategory,
// GalleryState,
// removeImage,
// } from 'features/gallery/store/gallerySlice';
// import {
// generationRequested,
// modelChangeRequested,
// modelConvertRequested,
// modelMergingRequested,
// setIsProcessing,
// } from 'features/system/store/systemSlice';
// import { InvokeTabName } from 'features/ui/store/tabMap';
// import { Socket } from 'socket.io-client';
import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/types/invokeai';
import type { RootState } from 'app/store/store';
import {
frontendToBackendParameters,
FrontendToBackendParametersConfig,
} from 'common/util/parameterTranslation';
import dateFormat from 'dateformat';
import {
GalleryCategory,
GalleryState,
removeImage,
} from 'features/gallery/store/gallerySlice';
import {
generationRequested,
modelChangeRequested,
modelConvertRequested,
modelMergingRequested,
setIsProcessing,
} from 'features/system/store/systemSlice';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { Socket } from 'socket.io-client';
// /**
// * Returns an object containing all functions which use `socketio.emit()`.
// * i.e. those which make server requests.
// */
// const makeSocketIOEmitters = (
// store: MiddlewareAPI<Dispatch<AnyAction>, RootState>,
// socketio: Socket
// ) => {
// // We need to dispatch actions to redux and get pieces of state from the store.
// const { dispatch, getState } = store;
/**
* Returns an object containing all functions which use `socketio.emit()`.
* i.e. those which make server requests.
*/
const makeSocketIOEmitters = (
store: MiddlewareAPI<Dispatch<AnyAction>, RootState>,
socketio: Socket
) => {
// We need to dispatch actions to redux and get pieces of state from the store.
const { dispatch, getState } = store;
// return {
// emitGenerateImage: (generationMode: InvokeTabName) => {
// dispatch(setIsProcessing(true));
return {
emitGenerateImage: (generationMode: InvokeTabName) => {
dispatch(setIsProcessing(true));
// const state: RootState = getState();
const state: RootState = getState();
// const {
// generation: generationState,
// postprocessing: postprocessingState,
// system: systemState,
// canvas: canvasState,
// } = state;
const {
generation: generationState,
postprocessing: postprocessingState,
system: systemState,
canvas: canvasState,
} = state;
// const frontendToBackendParametersConfig: FrontendToBackendParametersConfig =
// {
// generationMode,
// generationState,
// postprocessingState,
// canvasState,
// systemState,
// };
const frontendToBackendParametersConfig: FrontendToBackendParametersConfig =
{
generationMode,
generationState,
postprocessingState,
canvasState,
systemState,
};
// dispatch(generationRequested());
dispatch(generationRequested());
// const { generationParameters, esrganParameters, facetoolParameters } =
// frontendToBackendParameters(frontendToBackendParametersConfig);
const { generationParameters, esrganParameters, facetoolParameters } =
frontendToBackendParameters(frontendToBackendParametersConfig);
// socketio.emit(
// 'generateImage',
// generationParameters,
// esrganParameters,
// facetoolParameters
// );
socketio.emit(
'generateImage',
generationParameters,
esrganParameters,
facetoolParameters
);
// // we need to truncate the init_mask base64 else it takes up the whole log
// // TODO: handle maintaining masks for reproducibility in future
// if (generationParameters.init_mask) {
// generationParameters.init_mask = generationParameters.init_mask
// .substr(0, 64)
// .concat('...');
// }
// if (generationParameters.init_img) {
// generationParameters.init_img = generationParameters.init_img
// .substr(0, 64)
// .concat('...');
// }
// we need to truncate the init_mask base64 else it takes up the whole log
// TODO: handle maintaining masks for reproducibility in future
if (generationParameters.init_mask) {
generationParameters.init_mask = generationParameters.init_mask
.substr(0, 64)
.concat('...');
}
if (generationParameters.init_img) {
generationParameters.init_img = generationParameters.init_img
.substr(0, 64)
.concat('...');
}
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Image generation requested: ${JSON.stringify({
// ...generationParameters,
// ...esrganParameters,
// ...facetoolParameters,
// })}`,
// })
// );
// },
// emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
// dispatch(setIsProcessing(true));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Image generation requested: ${JSON.stringify({
...generationParameters,
...esrganParameters,
...facetoolParameters,
})}`,
})
);
},
emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
dispatch(setIsProcessing(true));
// const {
// postprocessing: {
// upscalingLevel,
// upscalingDenoising,
// upscalingStrength,
// },
// } = getState();
const {
postprocessing: {
upscalingLevel,
upscalingDenoising,
upscalingStrength,
},
} = getState();
// const esrganParameters = {
// upscale: [upscalingLevel, upscalingDenoising, upscalingStrength],
// };
// socketio.emit('runPostprocessing', imageToProcess, {
// type: 'esrgan',
// ...esrganParameters,
// });
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `ESRGAN upscale requested: ${JSON.stringify({
// file: imageToProcess.url,
// ...esrganParameters,
// })}`,
// })
// );
// },
// emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
// dispatch(setIsProcessing(true));
const esrganParameters = {
upscale: [upscalingLevel, upscalingDenoising, upscalingStrength],
};
socketio.emit('runPostprocessing', imageToProcess, {
type: 'esrgan',
...esrganParameters,
});
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `ESRGAN upscale requested: ${JSON.stringify({
file: imageToProcess.url,
...esrganParameters,
})}`,
})
);
},
emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
dispatch(setIsProcessing(true));
// const {
// postprocessing: { facetoolType, facetoolStrength, codeformerFidelity },
// } = getState();
const {
postprocessing: { facetoolType, facetoolStrength, codeformerFidelity },
} = getState();
// const facetoolParameters: Record<string, unknown> = {
// facetool_strength: facetoolStrength,
// };
const facetoolParameters: Record<string, unknown> = {
facetool_strength: facetoolStrength,
};
// if (facetoolType === 'codeformer') {
// facetoolParameters.codeformer_fidelity = codeformerFidelity;
// }
if (facetoolType === 'codeformer') {
facetoolParameters.codeformer_fidelity = codeformerFidelity;
}
// socketio.emit('runPostprocessing', imageToProcess, {
// type: facetoolType,
// ...facetoolParameters,
// });
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Face restoration (${facetoolType}) requested: ${JSON.stringify(
// {
// file: imageToProcess.url,
// ...facetoolParameters,
// }
// )}`,
// })
// );
// },
// emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
// const { url, uuid, category, thumbnail } = imageToDelete;
// dispatch(removeImage(imageToDelete));
// socketio.emit('deleteImage', url, thumbnail, uuid, category);
// },
// emitRequestImages: (category: GalleryCategory) => {
// const gallery: GalleryState = getState().gallery;
// const { earliest_mtime } = gallery.categories[category];
// socketio.emit('requestImages', category, earliest_mtime);
// },
// emitRequestNewImages: (category: GalleryCategory) => {
// const gallery: GalleryState = getState().gallery;
// const { latest_mtime } = gallery.categories[category];
// socketio.emit('requestLatestImages', category, latest_mtime);
// },
// emitCancelProcessing: () => {
// socketio.emit('cancel');
// },
// emitRequestSystemConfig: () => {
// socketio.emit('requestSystemConfig');
// },
// emitSearchForModels: (modelFolder: string) => {
// socketio.emit('searchForModels', modelFolder);
// },
// emitAddNewModel: (modelConfig: InvokeAI.InvokeModelConfigProps) => {
// socketio.emit('addNewModel', modelConfig);
// },
// emitDeleteModel: (modelName: string) => {
// socketio.emit('deleteModel', modelName);
// },
// emitConvertToDiffusers: (
// modelToConvert: InvokeAI.InvokeModelConversionProps
// ) => {
// dispatch(modelConvertRequested());
// socketio.emit('convertToDiffusers', modelToConvert);
// },
// emitMergeDiffusersModels: (
// modelMergeInfo: InvokeAI.InvokeModelMergingProps
// ) => {
// dispatch(modelMergingRequested());
// socketio.emit('mergeDiffusersModels', modelMergeInfo);
// },
// emitRequestModelChange: (modelName: string) => {
// dispatch(modelChangeRequested());
// socketio.emit('requestModelChange', modelName);
// },
// emitSaveStagingAreaImageToGallery: (url: string) => {
// socketio.emit('requestSaveStagingAreaImageToGallery', url);
// },
// emitRequestEmptyTempFolder: () => {
// socketio.emit('requestEmptyTempFolder');
// },
// };
// };
socketio.emit('runPostprocessing', imageToProcess, {
type: facetoolType,
...facetoolParameters,
});
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Face restoration (${facetoolType}) requested: ${JSON.stringify(
{
file: imageToProcess.url,
...facetoolParameters,
}
)}`,
})
);
},
emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
const { url, uuid, category, thumbnail } = imageToDelete;
dispatch(removeImage(imageToDelete));
socketio.emit('deleteImage', url, thumbnail, uuid, category);
},
emitRequestImages: (category: GalleryCategory) => {
const gallery: GalleryState = getState().gallery;
const { earliest_mtime } = gallery.categories[category];
socketio.emit('requestImages', category, earliest_mtime);
},
emitRequestNewImages: (category: GalleryCategory) => {
const gallery: GalleryState = getState().gallery;
const { latest_mtime } = gallery.categories[category];
socketio.emit('requestLatestImages', category, latest_mtime);
},
emitCancelProcessing: () => {
socketio.emit('cancel');
},
emitRequestSystemConfig: () => {
socketio.emit('requestSystemConfig');
},
emitSearchForModels: (modelFolder: string) => {
socketio.emit('searchForModels', modelFolder);
},
emitAddNewModel: (modelConfig: InvokeAI.InvokeModelConfigProps) => {
socketio.emit('addNewModel', modelConfig);
},
emitDeleteModel: (modelName: string) => {
socketio.emit('deleteModel', modelName);
},
emitConvertToDiffusers: (
modelToConvert: InvokeAI.InvokeModelConversionProps
) => {
dispatch(modelConvertRequested());
socketio.emit('convertToDiffusers', modelToConvert);
},
emitMergeDiffusersModels: (
modelMergeInfo: InvokeAI.InvokeModelMergingProps
) => {
dispatch(modelMergingRequested());
socketio.emit('mergeDiffusersModels', modelMergeInfo);
},
emitRequestModelChange: (modelName: string) => {
dispatch(modelChangeRequested());
socketio.emit('requestModelChange', modelName);
},
emitSaveStagingAreaImageToGallery: (url: string) => {
socketio.emit('requestSaveStagingAreaImageToGallery', url);
},
emitRequestEmptyTempFolder: () => {
socketio.emit('requestEmptyTempFolder');
},
};
};
// export default makeSocketIOEmitters;
export default makeSocketIOEmitters;
export default {};

View File

@@ -0,0 +1,4 @@
import { createAction } from '@reduxjs/toolkit';
import { InvokeTabName } from 'features/ui/store/tabMap';
export const userInvoked = createAction<InvokeTabName>('app/userInvoked');
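
`userInvoked` is a plain action whose payload is the active tab; the per-tab listeners further down match on it to build the appropriate graph. An illustrative dispatch from an invoke handler (the hook and wiring below are assumptions, not part of this diff):

```
// Illustrative only: dispatching userInvoked for the current tab lets the
// matching listener (txt2img / img2img / nodes / unifiedCanvas) take over.
import { useAppDispatch } from 'app/store/storeHooks';
import { userInvoked } from 'app/store/actions';
import { InvokeTabName } from 'features/ui/store/tabMap';

export const useInvoke = (activeTabName: InvokeTabName) => {
  const dispatch = useAppDispatch();
  return () => dispatch(userInvoked(activeTabName));
};
```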

View File

@@ -0,0 +1,8 @@
export const LOCALSTORAGE_KEYS = [
'chakra-ui-color-mode',
'i18nextLng',
'ROARR_FILTER',
'ROARR_LOG',
];
export const LOCALSTORAGE_PREFIX = '@@invokeai-';

View File

@@ -0,0 +1,36 @@
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
import { resultsPersistDenylist } from 'features/gallery/store/resultsPersistDenylist';
import { uploadsPersistDenylist } from 'features/gallery/store/uploadsPersistDenylist';
import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersistDenylist';
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
import { modelsPersistDenylist } from 'features/system/store/modelsPersistDenylist';
import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist';
import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist';
import { omit } from 'lodash-es';
import { SerializeFunction } from 'redux-remember';
const serializationDenylist: {
[key: string]: string[];
} = {
canvas: canvasPersistDenylist,
gallery: galleryPersistDenylist,
generation: generationPersistDenylist,
lightbox: lightboxPersistDenylist,
models: modelsPersistDenylist,
nodes: nodesPersistDenylist,
postprocessing: postprocessingPersistDenylist,
results: resultsPersistDenylist,
system: systemPersistDenylist,
// config: configPersistDenyList,
ui: uiPersistDenylist,
uploads: uploadsPersistDenylist,
// hotkeys: hotkeysPersistDenylist,
};
export const serialize: SerializeFunction = (data, key) => {
const result = omit(data, serializationDenylist[key]);
return JSON.stringify(result);
};

View File

@@ -0,0 +1,38 @@
import { initialCanvasState } from 'features/canvas/store/canvasSlice';
import { initialGalleryState } from 'features/gallery/store/gallerySlice';
import { initialResultsState } from 'features/gallery/store/resultsSlice';
import { initialUploadsState } from 'features/gallery/store/uploadsSlice';
import { initialLightboxState } from 'features/lightbox/store/lightboxSlice';
import { initialNodesState } from 'features/nodes/store/nodesSlice';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
import { initialConfigState } from 'features/system/store/configSlice';
import { initialModelsState } from 'features/system/store/modelSlice';
import { initialSystemState } from 'features/system/store/systemSlice';
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
import { initialUIState } from 'features/ui/store/uiSlice';
import { defaultsDeep } from 'lodash-es';
import { UnserializeFunction } from 'redux-remember';
const initialStates: {
[key: string]: any;
} = {
canvas: initialCanvasState,
gallery: initialGalleryState,
generation: initialGenerationState,
lightbox: initialLightboxState,
models: initialModelsState,
nodes: initialNodesState,
postprocessing: initialPostprocessingState,
results: initialResultsState,
system: initialSystemState,
config: initialConfigState,
ui: initialUIState,
uploads: initialUploadsState,
hotkeys: initialHotkeysState,
};
export const unserialize: UnserializeFunction = (data, key) => {
const result = defaultsDeep(JSON.parse(data), initialStates[key]);
return result;
};
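
`serialize` and `unserialize` are the persistence hooks for `redux-remember`: the first strips each slice's denylisted keys before writing, the second deep-merges the stored JSON over the slice's initial state so newly added defaults survive older payloads. A hedged sketch of how they might be handed to the enhancer; the store wiring, reducer keys, and paths below are illustrative, since the real store setup is not part of this diff:

```
// Illustrative wiring only; reducer keys and import paths are stand-ins.
import { combineReducers, configureStore } from '@reduxjs/toolkit';
import { rememberEnhancer, rememberReducer } from 'redux-remember';
import { serialize } from './enhancers/reduxRemember/serialize'; // paths assumed
import { unserialize } from './enhancers/reduxRemember/unserialize';

const rootReducer = combineReducers({
  ui: (state: Record<string, unknown> = {}) => state, // stand-in slices
  canvas: (state: Record<string, unknown> = {}) => state,
});

const rememberedKeys = ['ui', 'canvas'];

export const store = configureStore({
  reducer: rememberReducer(rootReducer),
  enhancers: [
    rememberEnhancer(window.localStorage, rememberedKeys, {
      serialize,
      unserialize,
    }),
  ],
});
```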

View File

@@ -0,0 +1,30 @@
import { AnyAction } from '@reduxjs/toolkit';
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { forEach } from 'lodash-es';
import { Graph } from 'services/api';
export const actionSanitizer = <A extends AnyAction>(action: A): A => {
if (isAnyGraphBuilt(action)) {
if (action.payload.nodes) {
const sanitizedNodes: Graph['nodes'] = {};
// Sanitize nodes as needed
forEach(action.payload.nodes, (node, key) => {
// Don't log the whole freaking dataURL
if (node.type === 'dataURL_image') {
const { dataURL, ...rest } = node;
sanitizedNodes[key] = { ...rest, dataURL: '<dataURL>' };
} else {
sanitizedNodes[key] = { ...node };
}
});
return {
...action,
payload: { ...action.payload, nodes: sanitizedNodes },
};
}
}
return action;
};

View File

@@ -0,0 +1,11 @@
export const actionsDenylist = [
'canvas/setCursorPosition',
'canvas/setStageCoordinates',
'canvas/setStageScale',
'canvas/setIsDrawing',
'canvas/setBoundingBoxCoordinates',
'canvas/setBoundingBoxDimensions',
'canvas/setIsDrawing',
'canvas/addPointToCurrentLine',
'socket/generatorProgress',
];

View File

@@ -0,0 +1,3 @@
export const stateSanitizer = <S>(state: S): S => {
return state;
};
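
`actionSanitizer`, `actionsDenylist`, and `stateSanitizer` are Redux DevTools options: the sanitizer swaps giant `dataURL` payloads out of logged graph actions, the denylist hides the high-frequency canvas and progress actions, and the state sanitizer is currently a pass-through. A hedged sketch of how they would typically be passed to `configureStore`; the wiring, paths, and stand-in reducer are assumptions, since the real store setup is not in this diff:

```
// Illustrative wiring only; paths and the stand-in reducer are assumptions.
import { configureStore } from '@reduxjs/toolkit';
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';

export const store = configureStore({
  reducer: { placeholder: (state: Record<string, unknown> = {}) => state },
  devTools: {
    actionSanitizer,
    stateSanitizer,
    actionsDenylist,
  },
});
```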

View File

@@ -0,0 +1,54 @@
import {
createListenerMiddleware,
addListener,
ListenerEffect,
AnyAction,
} from '@reduxjs/toolkit';
import type { TypedStartListening, TypedAddListener } from '@reduxjs/toolkit';
import type { RootState, AppDispatch } from '../../store';
import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
import { addImageResultReceivedListener } from './listeners/invocationComplete';
import { addImageUploadedListener } from './listeners/imageUploaded';
import { addRequestedImageDeletionListener } from './listeners/imageDeleted';
import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery';
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
import { addCanvasMergedListener } from './listeners/canvasMerged';
export const listenerMiddleware = createListenerMiddleware();
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
export const startAppListening =
listenerMiddleware.startListening as AppStartListening;
export const addAppListener = addListener as TypedAddListener<
RootState,
AppDispatch
>;
export type AppListenerEffect = ListenerEffect<
AnyAction,
RootState,
AppDispatch
>;
addImageUploadedListener();
addInitialImageSelectedListener();
addImageResultReceivedListener();
addRequestedImageDeletionListener();
addUserInvokedCanvasListener();
addUserInvokedNodesListener();
addUserInvokedTextToImageListener();
addUserInvokedImageToImageListener();
addCanvasSavedToGalleryListener();
addCanvasDownloadedAsImageListener();
addCanvasCopiedToClipboardListener();
addCanvasMergedListener();
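
`startAppListening` is `listenerMiddleware.startListening` cast to the app's `RootState`/`AppDispatch` types, which is why the listener modules below get fully typed `getState` and `dispatch` without importing the store. A minimal sketch of the pattern those modules share; the action and effect body are illustrative:

```
// Illustrative only: the shape every add*Listener module in this PR follows.
import { createAction } from '@reduxjs/toolkit';
import { startAppListening } from '..'; // path assumed, mirroring the listener files

const somethingHappened = createAction<string>('app/somethingHappened');

export const addSomethingHappenedListener = () => {
  startAppListening({
    actionCreator: somethingHappened,
    effect: async (action, { getState, dispatch }) => {
      // getState() is typed as RootState, dispatch as AppDispatch here.
      const state = getState();
      console.log('slices in store:', Object.keys(state).length);
      dispatch({ type: 'app/somethingHandled', payload: action.payload });
    },
  });
};
```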

View File

@@ -0,0 +1,33 @@
import { canvasCopiedToClipboard } from 'features/canvas/store/actions';
import { startAppListening } from '..';
import { log } from 'app/logging/useLogger';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice';
import { copyBlobToClipboard } from 'features/canvas/util/copyBlobToClipboard';
const moduleLog = log.child({ namespace: 'canvasCopiedToClipboardListener' });
export const addCanvasCopiedToClipboardListener = () => {
startAppListening({
actionCreator: canvasCopiedToClipboard,
effect: async (action, { dispatch, getState }) => {
const state = getState();
const blob = await getBaseLayerBlob(state);
if (!blob) {
moduleLog.error('Problem getting base layer blob');
dispatch(
addToast({
title: 'Problem Copying Canvas',
description: 'Unable to export base layer',
status: 'error',
})
);
return;
}
copyBlobToClipboard(blob);
},
});
};
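
The listener above delegates the final step to `copyBlobToClipboard`, which is not included in this diff. A plausible sketch of such a helper using the async Clipboard API, offered as an assumption about the utility rather than its actual source:

```
// Hedged sketch only: the project's real copyBlobToClipboard may differ.
// Wraps the blob in a ClipboardItem and writes it via the async Clipboard API.
export const copyBlobToClipboard = (blob: Blob): Promise<void> =>
  navigator.clipboard.write([
    new ClipboardItem({ [blob.type || 'image/png']: blob }),
  ]);
```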

View File

@@ -0,0 +1,33 @@
import { canvasDownloadedAsImage } from 'features/canvas/store/actions';
import { startAppListening } from '..';
import { log } from 'app/logging/useLogger';
import { downloadBlob } from 'features/canvas/util/downloadBlob';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice';
const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' });
export const addCanvasDownloadedAsImageListener = () => {
startAppListening({
actionCreator: canvasDownloadedAsImage,
effect: async (action, { dispatch, getState }) => {
const state = getState();
const blob = await getBaseLayerBlob(state);
if (!blob) {
moduleLog.error('Problem getting base layer blob');
dispatch(
addToast({
title: 'Problem Downloading Canvas',
description: 'Unable to export base layer',
status: 'error',
})
);
return;
}
downloadBlob(blob, 'mergedCanvas.png');
},
});
};

View File

@@ -0,0 +1,88 @@
import { canvasMerged } from 'features/canvas/store/actions';
import { startAppListening } from '..';
import { log } from 'app/logging/useLogger';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice';
import { imageUploaded } from 'services/thunks/image';
import { v4 as uuidv4 } from 'uuid';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
import { setMergedCanvas } from 'features/canvas/store/canvasSlice';
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
const moduleLog = log.child({ namespace: 'canvasCopiedToClipboardListener' });
export const addCanvasMergedListener = () => {
startAppListening({
actionCreator: canvasMerged,
effect: async (action, { dispatch, getState, take }) => {
const state = getState();
const blob = await getBaseLayerBlob(state, true);
if (!blob) {
moduleLog.error('Problem getting base layer blob');
dispatch(
addToast({
title: 'Problem Merging Canvas',
description: 'Unable to export base layer',
status: 'error',
})
);
return;
}
const canvasBaseLayer = getCanvasBaseLayer();
if (!canvasBaseLayer) {
moduleLog.error('Problem getting canvas base layer');
dispatch(
addToast({
title: 'Problem Merging Canvas',
description: 'Unable to export base layer',
status: 'error',
})
);
return;
}
const baseLayerRect = canvasBaseLayer.getClientRect({
relativeTo: canvasBaseLayer.getParent(),
});
const filename = `mergedCanvas_${uuidv4()}.png`;
dispatch(
imageUploaded({
imageType: 'intermediates',
formData: {
file: new File([blob], filename, { type: 'image/png' }),
},
})
);
const [{ payload }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === filename
);
const mergedCanvasImage = deserializeImageResponse(payload.response);
dispatch(
setMergedCanvas({
kind: 'image',
layer: 'base',
image: mergedCanvasImage,
...baseLayerRect,
})
);
dispatch(
addToast({
title: 'Canvas Merged',
status: 'success',
})
);
},
});
};

View File

@@ -0,0 +1,40 @@
import { canvasSavedToGallery } from 'features/canvas/store/actions';
import { startAppListening } from '..';
import { log } from 'app/logging/useLogger';
import { imageUploaded } from 'services/thunks/image';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice';
const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' });
export const addCanvasSavedToGalleryListener = () => {
startAppListening({
actionCreator: canvasSavedToGallery,
effect: async (action, { dispatch, getState }) => {
const state = getState();
const blob = await getBaseLayerBlob(state);
if (!blob) {
moduleLog.error('Problem getting base layer blob');
dispatch(
addToast({
title: 'Problem Saving Canvas',
description: 'Unable to export base layer',
status: 'error',
})
);
return;
}
dispatch(
imageUploaded({
imageType: 'results',
formData: {
file: new File([blob], 'mergedCanvas.png', { type: 'image/png' }),
},
})
);
},
});
};

View File

@@ -0,0 +1,59 @@
import { requestedImageDeletion } from 'features/gallery/store/actions';
import { startAppListening } from '..';
import { imageDeleted } from 'services/thunks/image';
import { log } from 'app/logging/useLogger';
import { clamp } from 'lodash-es';
import { imageSelected } from 'features/gallery/store/gallerySlice';
const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
export const addRequestedImageDeletionListener = () => {
startAppListening({
actionCreator: requestedImageDeletion,
effect: (action, { dispatch, getState }) => {
const image = action.payload;
if (!image) {
moduleLog.warn('No image provided');
return;
}
const { name, type } = image;
if (type !== 'uploads' && type !== 'results') {
moduleLog.warn({ data: image }, `Invalid image type ${type}`);
return;
}
const selectedImageName = getState().gallery.selectedImage?.name;
if (selectedImageName === name) {
const allIds = getState()[type].ids;
const allEntities = getState()[type].entities;
const deletedImageIndex = allIds.findIndex(
(result) => result.toString() === name
);
const filteredIds = allIds.filter((id) => id.toString() !== name);
const newSelectedImageIndex = clamp(
deletedImageIndex,
0,
filteredIds.length - 1
);
const newSelectedImageId = filteredIds[newSelectedImageIndex];
const newSelectedImage = allEntities[newSelectedImageId];
if (newSelectedImageId) {
dispatch(imageSelected(newSelectedImage));
} else {
dispatch(imageSelected());
}
}
dispatch(imageDeleted({ imageName: name, imageType: type }));
},
});
};

View File

@@ -0,0 +1,46 @@
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
import { startAppListening } from '..';
import { uploadAdded } from 'features/gallery/store/uploadsSlice';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { imageUploaded } from 'services/thunks/image';
import { addToast } from 'features/system/store/systemSlice';
import { initialImageSelected } from 'features/parameters/store/actions';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { resultAdded } from 'features/gallery/store/resultsSlice';
export const addImageUploadedListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.payload.response.image_type !== 'intermediates',
effect: (action, { dispatch, getState }) => {
const { response } = action.payload;
const { imageType } = action.meta.arg;
const state = getState();
const image = deserializeImageResponse(response);
if (imageType === 'uploads') {
dispatch(uploadAdded(image));
dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
if (state.gallery.shouldAutoSwitchToNewImages) {
dispatch(imageSelected(image));
}
if (action.meta.arg.activeTabName === 'img2img') {
dispatch(initialImageSelected(image));
}
if (action.meta.arg.activeTabName === 'unifiedCanvas') {
dispatch(setInitialCanvasImage(image));
}
}
if (imageType === 'results') {
dispatch(resultAdded(image));
}
},
});
};

View File

@@ -0,0 +1,54 @@
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { Image, isInvokeAIImage } from 'app/types/invokeai';
import { selectResultsById } from 'features/gallery/store/resultsSlice';
import { selectUploadsById } from 'features/gallery/store/uploadsSlice';
import { t } from 'i18next';
import { addToast } from 'features/system/store/systemSlice';
import { startAppListening } from '..';
import { initialImageSelected } from 'features/parameters/store/actions';
import { makeToast } from 'app/components/Toaster';
export const addInitialImageSelectedListener = () => {
startAppListening({
actionCreator: initialImageSelected,
effect: (action, { getState, dispatch }) => {
if (!action.payload) {
dispatch(
addToast(
makeToast({ title: t('toast.imageNotLoadedDesc'), status: 'error' })
)
);
return;
}
if (isInvokeAIImage(action.payload)) {
dispatch(initialImageChanged(action.payload));
dispatch(addToast(makeToast(t('toast.sentToImageToImage'))));
return;
}
const { name, type } = action.payload;
let image: Image | undefined;
const state = getState();
if (type === 'results') {
image = selectResultsById(state, name);
} else if (type === 'uploads') {
image = selectUploadsById(state, name);
}
if (!image) {
dispatch(
addToast(
makeToast({ title: t('toast.imageNotLoadedDesc'), status: 'error' })
)
);
return;
}
dispatch(initialImageChanged(image));
dispatch(addToast(makeToast(t('toast.sentToImageToImage'))));
},
});
};

View File

@@ -0,0 +1,88 @@
import { invocationComplete } from 'services/events/actions';
import { isImageOutput } from 'services/types/guards';
import {
buildImageUrls,
extractTimestampFromImageName,
} from 'services/util/deserializeImageField';
import { Image } from 'app/types/invokeai';
import { resultAdded } from 'features/gallery/store/resultsSlice';
import { imageReceived, thumbnailReceived } from 'services/thunks/image';
import { startAppListening } from '..';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
const nodeDenylist = ['dataURL_image'];
export const addImageResultReceivedListener = () => {
startAppListening({
predicate: (action) => {
if (
invocationComplete.match(action) &&
isImageOutput(action.payload.data.result)
) {
return true;
}
return false;
},
effect: (action, { getState, dispatch }) => {
if (!invocationComplete.match(action)) {
return;
}
const { data, shouldFetchImages } = action.payload;
const { result, node, graph_execution_state_id } = data;
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
const name = result.image.image_name;
const type = result.image.image_type;
const state = getState();
// if we need to refetch, set URLs to placeholder for now
const { url, thumbnail } = shouldFetchImages
? { url: '', thumbnail: '' }
: buildImageUrls(type, name);
const timestamp = extractTimestampFromImageName(name);
const image: Image = {
name,
type,
url,
thumbnail,
metadata: {
created: timestamp,
width: result.width,
height: result.height,
invokeai: {
session_id: graph_execution_state_id,
...(node ? { node } : {}),
},
},
};
dispatch(resultAdded(image));
if (state.gallery.shouldAutoSwitchToNewImages) {
dispatch(imageSelected(image));
}
if (state.config.shouldFetchImages) {
dispatch(imageReceived({ imageName: name, imageType: type }));
dispatch(
thumbnailReceived({
thumbnailName: name,
thumbnailType: type,
})
);
}
if (
graph_execution_state_id ===
state.canvas.layerState.stagingArea.sessionId
) {
dispatch(addImageToStagingArea(image));
}
}
},
});
};

View File

@@ -0,0 +1,164 @@
import { startAppListening } from '..';
import { sessionCreated, sessionInvoked } from 'services/thunks/session';
import { buildCanvasGraphComponents } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
import { log } from 'app/logging/useLogger';
import { canvasGraphBuilt } from 'features/nodes/store/actions';
import { imageUploaded } from 'services/thunks/image';
import { v4 as uuidv4 } from 'uuid';
import { Graph } from 'services/api';
import {
canvasSessionIdChanged,
stagingAreaInitialized,
} from 'features/canvas/store/canvasSlice';
import { userInvoked } from 'app/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
const moduleLog = log.child({ namespace: 'invoke' });
/**
* This listener is responsible for building the canvas graph and blobs when the user invokes the canvas.
* It is also responsible for uploading the base and mask layers to the server.
*/
export const addUserInvokedCanvasListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'unifiedCanvas',
effect: async (action, { getState, dispatch, take }) => {
const state = getState();
// Build canvas blobs
const canvasBlobsAndImageData = await getCanvasData(state);
if (!canvasBlobsAndImageData) {
moduleLog.error('Unable to create canvas data');
return;
}
const { baseBlob, baseImageData, maskBlob, maskImageData } =
canvasBlobsAndImageData;
// Determine the generation mode
const generationMode = getCanvasGenerationMode(
baseImageData,
maskImageData
);
if (state.system.enableImageDebugging) {
const baseDataURL = await blobToDataURL(baseBlob);
const maskDataURL = await blobToDataURL(maskBlob);
openBase64ImageInTab([
{ base64: maskDataURL, caption: 'mask b64' },
{ base64: baseDataURL, caption: 'image b64' },
]);
}
moduleLog.debug(`Generation mode: ${generationMode}`);
// Build the canvas graph
const graphComponents = await buildCanvasGraphComponents(
state,
generationMode
);
if (!graphComponents) {
moduleLog.error('Problem building graph');
return;
}
const { rangeNode, iterateNode, baseNode, edges } = graphComponents;
// Upload the base layer, to be used as init image
const baseFilename = `${uuidv4()}.png`;
dispatch(
imageUploaded({
imageType: 'intermediates',
formData: {
file: new File([baseBlob], baseFilename, { type: 'image/png' }),
},
})
);
if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') {
const [{ payload: basePayload }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === baseFilename
);
const { image_name: baseName, image_type: baseType } =
basePayload.response;
baseNode.image = {
image_name: baseName,
image_type: baseType,
};
}
// Upload the mask layer image
const maskFilename = `${uuidv4()}.png`;
if (baseNode.type === 'inpaint') {
dispatch(
imageUploaded({
imageType: 'intermediates',
formData: {
file: new File([maskBlob], maskFilename, { type: 'image/png' }),
},
})
);
const [{ payload: maskPayload }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === maskFilename
);
const { image_name: maskName, image_type: maskType } =
maskPayload.response;
baseNode.mask = {
image_name: maskName,
image_type: maskType,
};
}
// Assemble!
const nodes: Graph['nodes'] = {
[rangeNode.id]: rangeNode,
[iterateNode.id]: iterateNode,
[baseNode.id]: baseNode,
};
const graph = { nodes, edges };
dispatch(canvasGraphBuilt(graph));
moduleLog({ data: graph }, 'Canvas graph built');
// Actually create the session
dispatch(sessionCreated({ graph }));
// Wait for the session to be invoked (this is just the HTTP request to start processing)
const [{ meta }] = await take(sessionInvoked.fulfilled.match);
const { sessionId } = meta.arg;
if (!state.canvas.layerState.stagingArea.boundingBox) {
dispatch(
stagingAreaInitialized({
sessionId,
boundingBox: {
...state.canvas.boundingBoxCoordinates,
...state.canvas.boundingBoxDimensions,
},
})
);
}
dispatch(canvasSessionIdChanged(sessionId));
},
});
};

View File

@@ -0,0 +1,24 @@
import { startAppListening } from '..';
import { buildImageToImageGraph } from 'features/nodes/util/graphBuilders/buildImageToImageGraph';
import { sessionCreated } from 'services/thunks/session';
import { log } from 'app/logging/useLogger';
import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
import { userInvoked } from 'app/store/actions';
const moduleLog = log.child({ namespace: 'invoke' });
export const addUserInvokedImageToImageListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'img2img',
effect: (action, { getState, dispatch }) => {
const state = getState();
const graph = buildImageToImageGraph(state);
dispatch(imageToImageGraphBuilt(graph));
moduleLog({ data: graph }, 'Image to Image graph built');
dispatch(sessionCreated({ graph }));
},
});
};

View File

@@ -0,0 +1,24 @@
import { startAppListening } from '..';
import { sessionCreated } from 'services/thunks/session';
import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGraph';
import { log } from 'app/logging/useLogger';
import { nodesGraphBuilt } from 'features/nodes/store/actions';
import { userInvoked } from 'app/store/actions';
const moduleLog = log.child({ namespace: 'invoke' });
export const addUserInvokedNodesListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'nodes',
effect: (action, { getState, dispatch }) => {
const state = getState();
const graph = buildNodesGraph(state);
dispatch(nodesGraphBuilt(graph));
moduleLog({ data: graph }, 'Nodes graph built');
dispatch(sessionCreated({ graph }));
},
});
};

View File

@@ -0,0 +1,24 @@
import { startAppListening } from '..';
import { buildTextToImageGraph } from 'features/nodes/util/graphBuilders/buildTextToImageGraph';
import { sessionCreated } from 'services/thunks/session';
import { log } from 'app/logging/useLogger';
import { textToImageGraphBuilt } from 'features/nodes/store/actions';
import { userInvoked } from 'app/store/actions';
const moduleLog = log.child({ namespace: 'invoke' });
export const addUserInvokedTextToImageListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'txt2img',
effect: (action, { getState, dispatch }) => {
const state = getState();
const graph = buildTextToImageGraph(state);
dispatch(textToImageGraphBuilt(graph));
moduleLog({ data: graph }, 'Text to Image graph built');
dispatch(sessionCreated({ graph }));
},
});
};

Some files were not shown because too many files have changed in this diff.