Compare commits

38 Commits

Author SHA1 Message Date
054e963bef add basic autocomplete functionality to node cli (#3035)
- Commands, invocations and their parameters will now autocomplete using
introspection.
- Two types of parameter *arguments* will also autocomplete:
  - --sampler_name  will autocomplete the scheduler name
  - --model will autocomplete the model name
- There don't seem to be commands for reading/writing image files yet,
so path autocompletion is not implemented
2023-03-30 08:25:36 -04:00
afb66a7884 Merge branch 'main' into feat/node-cli-autocompleter 2023-03-30 07:51:51 -04:00
25ae36ceb5 I18n build mode (#3051)
Add a build mode option to bundle the English translation with the UI
2023-03-29 22:26:45 -04:00
3ae8daedaa Merge branch 'main' into i18n-build-mode 2023-03-29 22:26:17 -04:00
b913e1e11e improve importation and conversion of legacy checkpoint files (#3053)
A long-standing issue with importing legacy checkpoints (both ckpt and
safetensors) is that the user has to identify the correct config file,
either by providing its path or by selecting which type of model the
checkpoint is (e.g. "v1 inpainting"). In addition, some users wish to
provide custom VAEs for use with the model. Currently this is done in
the WebUI by importing the model, editing it, and then typing in the
path to the VAE.

## Model configuration file selection

To improve the user experience, the model manager's `heuristic_import()`
method has been enhanced as follows:

1. When initially called, the caller can pass a config file path, in
which case it will be used.

2. If no config file is provided, the method looks for a .yaml file in the
same directory as the model which bears the same basename, e.g.
```
   my-new-model.safetensors
   my-new-model.yaml
```
The yaml file is then used as the configuration file for importation and
conversion.

3. If no such file is found, then the method opens up the checkpoint and
probes it to determine whether it is V1, V1-inpaint or V2. If it is a V1
format, then the appropriate v1-inference.yaml config file is used.
Unfortunately there are two V2 variants that cannot be distinguished by
introspection.

4. If the probe algorithm is unable to determine the model type, then
its last-ditch effort is to execute an optional callback function that
can be provided by the caller. This callback, named
`config_file_callback`, receives the path to the legacy checkpoint and
returns the path to the config file to use. The CLI uses this to put up
a multiple-choice prompt to the user. The WebUI **could** use this to
prompt the user to choose from a radio-button selection.

5. If the config file cannot be determined, then the import is
abandoned. (The selection order is sketched below.)
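
In outline, the selection logic reads like this. This is an illustrative
sketch only: `probe_model_type` stands in for the real checkpoint prober,
and the callback matches the `config_file_callback` signature described
above.

```python
from pathlib import Path
from typing import Callable, Optional

def select_config_file(
    checkpoint: Path,
    config_file: Optional[Path] = None,
    config_file_callback: Optional[Callable[[Path], Path]] = None,
) -> Optional[Path]:
    # 1. a caller-provided config file wins
    if config_file:
        return config_file
    # 2. a like-named .yaml next to the checkpoint
    sidecar = checkpoint.with_suffix(".yaml")
    if sidecar.exists():
        return sidecar
    # 3. probing (V1 / V1-inpaint / V2) would go here; the two V2 variants
    #    cannot be distinguished by introspection, so probing may fail...
    # 4. ...in which case the optional callback gets the last word
    if config_file_callback:
        return config_file_callback(checkpoint)
    # 5. otherwise the import is abandoned
    return None
```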

## Custom VAE Selection

The user can attach a custom VAE to the imported and converted model by
copying the desired VAE into the same directory as the file to be
imported, and giving it the same basename. E.g.:

```
    my-new-model.safetensors
    my-new-model.vae.pt
```

For this to work, the VAE filename must end with ".vae.pt", ".vae.ckpt", or
".vae.safetensors". The indicated VAE will be converted into diffusers
format and stored with the converted model files, so the ".pt" file can
be deleted after conversion.
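
The lookup itself is small; a minimal sketch (the suffix list mirrors the
one in the diff further down this page):

```python
from pathlib import Path
from typing import Optional

def find_custom_vae(model_path: Path) -> Optional[Path]:
    # try each recognized side-car suffix next to the checkpoint
    for suffix in ("pt", "ckpt", "safetensors"):
        candidate = model_path.with_suffix(f".vae.{suffix}")
        if candidate.exists():
            return candidate
    return None

# find_custom_vae(Path("models/my-new-model.safetensors")) returns
# models/my-new-model.vae.pt when that file exists.
```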

No facility is currently provided to swap a diffusers VAE at import
time, but this can be done after the fact using the WebUI and CLI's
model editing functions.

Note that this is the same fix that was applied to the 2.3 branch in
#3043. This PR applies it to `main`.
2023-03-29 17:22:15 -04:00
3c4b6d5735 Merge branch 'main' into enhance/heuristic-import-improvements 2023-03-29 16:54:43 -04:00
e6123eac19 Merge branch 'main' into i18n-build-mode 2023-03-29 05:33:14 -07:00
30ca25897e Fix bugs in online ckpt conversion of 2.0 models (#3057)
## Enable the on-the-fly conversion of models based on SD 2.0/2.1 into
diffusers

This commit fixes bugs related to the on-the-fly conversion and loading
of legacy checkpoint models built on SD-2.0 base.

- When legacy checkpoints built on SD-2.0 models were converted
on-the-fly using --ckpt_convert, generation would crash with a precision
incompatibility error. This problem has been found and fixed.
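
The crash class is easy to reproduce in isolation. A hypothetical minimal
example (not InvokeAI code) of the underlying dtype mismatch and its cure:

```python
import torch

layer = torch.nn.Linear(4, 4)                    # parameters are float32
x = torch.ones(1, 4, dtype=torch.float16)
try:
    layer(x)                                     # mixed-precision matmul
except RuntimeError as err:
    print(f"precision incompatibility: {err}")

# The actual fix (see the conversion diff below) casts the checkpoint's
# submodels to the working precision when assembling the pipeline, e.g.
# pipeline_class(vae=vae.to(precision), text_encoder=text_model.to(precision),
#                unet=unet.to(precision), ...)
```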
2023-03-28 23:34:53 -04:00
abaee6b9ed Merge branch 'main' into feat/node-cli-autocompleter 2023-03-28 23:32:10 -04:00
4d7c9e1ab7 Merge branch 'main' into bugfix/convert-2.0-models 2023-03-28 23:01:36 -04:00
cc5687f26c [nodes] downgrade fastapi+uvicorn to fix openapi schema 2023-03-28 22:53:20 -04:00
78e76f26f9 Merge branch 'main' into i18n-build-mode 2023-03-28 11:04:32 -04:00
9a7580dedd fix bugs in online ckpt conversion of 2.0 models
This commit fixes bugs related to the on-the-fly conversion and loading of
legacy checkpoint models built on SD-2.0 base.

- When legacy checkpoints built on SD-2.0 models were converted
  on-the-fly using --ckpt_convert, generation would crash with a
  precision incompatibility error.
2023-03-28 00:17:20 -04:00
dc2da8cff4 Doc: updating ROCm version in documentation (#3041)
The PyTorch ROCm version in the documentation is outdated (`rocm5.2`),
which leads to errors during the installation of InvokeAI.

This PR updates the documentation to the latest PyTorch ROCm version,
`5.4.2`.
2023-03-27 22:37:43 -04:00
019a9f0329 address change requests in PR
1. Prompt has changed to "invoke> ".
2. The function that initializes the autocompleter has been renamed to "set_autocompleter()"
2023-03-27 12:20:24 -04:00
fe5d9ad171 improve importation and conversion of legacy checkpoint files
A long-standing issue with importing legacy checkpoints (both ckpt and
safetensors) is that the user has to identify the correct config file,
either by providing its path or by selecting which type of model the
checkpoint is (e.g. "v1 inpainting"). In addition, some users wish to
provide custom VAEs for use with the model. Currently this is done in
the WebUI by importing the model, editing it, and then typing in the
path to the VAE.

To improve the user experience, the model manager's
`heuristic_import()` method has been enhanced as follows:

1. When initially called, the caller can pass a config file path, in
which case it will be used.

2. If no config file is provided, the method looks for a .yaml file in the
same directory as the model which bears the same basename, e.g.
```
   my-new-model.safetensors
   my-new-model.yaml
```
   The yaml file is then used as the configuration file for
   importation and conversion.

3. If no such file is found, then the method opens up the checkpoint
   and probes it to determine whether it is V1, V1-inpaint or V2.
   If it is a V1 format, then the appropriate v1-inference.yaml config
   file is used. Unfortunately there are two V2 variants that cannot be
   distinguished by introspection.

4. If the probe algorithm is unable to determine the model type, then its
   last-ditch effort is to execute an optional callback function that can
   be provided by the caller. This callback, named `config_file_callback`,
   receives the path to the legacy checkpoint and returns the path to the
   config file to use. The CLI uses this to put up a multiple-choice prompt
   to the user. The WebUI **could** use this to prompt the user to choose
   from a radio-button selection.

5. If the config file cannot be determined, then the import is abandoned.

The user can attach a custom VAE to the imported and converted model
by copying the desired VAE into the same directory as the file to be
imported, and giving it the same basename. E.g.:

```
    my-new-model.safetensors
    my-new-model.vae.pt
```

For this to work, the VAE filename must end with ".vae.pt", ".vae.ckpt", or
".vae.safetensors". The indicated VAE will be converted into diffusers
format and stored with the converted model files, so the ".pt" file
can be deleted after conversion.

No facility is currently provided to swap a diffusers VAE at import
time, but this can be done after the fact using the WebUI and CLI's
model editing functions.
2023-03-27 11:27:45 -04:00
dbc0093b31 Merge remote-tracking branch 'origin' into i18n-build-mode 2023-03-27 10:57:41 -04:00
92e512b8b6 add package mode option for i18next 2023-03-27 10:49:52 -04:00
dc14701d20 Merge branch 'main' into feat/node-cli-autocompleter 2023-03-26 23:46:10 -04:00
737e0f3085 doc: fixing error in rocm version 2023-03-26 12:40:20 +02:00
81b7ea4362 doc: updating ROCm version for pip install 2023-03-26 12:32:12 +02:00
09dfde0ba1 fix(ui): fix viewer tooltip localisation strings (#3037)
fixes #2923
2023-03-26 20:35:52 +13:00
3ba7e966b5 Merge branch 'main' into fix/ui/viewer-localisation 2023-03-26 20:35:12 +13:00
a1cd4834d1 nodes: add cancelation, updated progress callback, typing fixes (#3036)
keeping `main` up to date with my api nodes branch:
- bd7e515290: [nodes] Add cancelation to
the API @Kyle0654
- 5fe38f7: fix(backend): simple typing fixes
  - just picking some low-hanging fruit to improve IDE hinting
- c34ac91: fix(nodes): fix cancel; fix callback for img2img, inpaint
- makes node cancellation immediate, fixes progress images on nodes, and
fixes callbacks for img2img/inpaint
- 4221cf7: fix(nodes): fix schema generation for output classes
- did this previously for some other class; needed to not have node
outputs be optional
2023-03-26 20:34:27 +13:00
a724038dc6 fix(ui): fix viewer tooltip localisation strings
fixes #2923
2023-03-26 17:43:00 +11:00
4221cf7731 fix(nodes): fix schema generation for output classes
All output classes need to have their properties flagged as `required` for the schema generation to work as needed.
2023-03-26 17:20:10 +11:00
c34ac91ff0 fix(nodes): fix cancel; fix callback for img2img, inpaint 2023-03-26 17:07:40 +11:00
5fe38f7c88 fix(backend): simple typing fixes 2023-03-26 17:07:03 +11:00
bd7e515290 [nodes] Add cancelation to the API 2023-03-26 15:47:32 +11:00
076fac07eb feat[web]: use the predicted denoised image for previews (#2915)
Some schedulers report not only the noisy latents at the current
timestep, but also their estimate so far of what the de-noised latents
will be.

It makes for a more legible preview than the noisy latents do.

I think this is a huge improvement, but there are a few considerations:
- Need to not spook @JPPhoto by changing how previews look.
- Some schedulers (most notably **DPM Solver++**) don't provide this
data, and it falls back to the current behavior there. That's not
terrible, but seeing such a big difference in how _previews_ look from
one scheduler to the next might mislead people into thinking there's a
bigger difference in their overall effectiveness than there really is.

My fear of configuration-option-overwhelm leaves me inclined to _not_
add a configuration option for this, but we could.
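
The fallback logic is small. A self-contained sketch; the `State` class
below is a stand-in for `PipelineIntermediateState`, with field names taken
from the diffs later on this page:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class State:                      # stand-in for PipelineIntermediateState
    latents: str
    predicted_original: Optional[str] = None

def pick_preview_sample(state: State):
    # Prefer the scheduler's running estimate of the de-noised latents;
    # fall back to the noisy latents (e.g. DPM Solver++ reports none).
    if state.predicted_original is not None:
        return state.predicted_original
    return state.latents

print(pick_preview_sample(State("noisy", "clean")))  # clean
print(pick_preview_sample(State("noisy")))           # noisy
```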
2023-03-26 00:29:00 -04:00
9348161600 add basic autocomplete functionality to node cli
- Commands, invocations and their parameters will now autocomplete
  using introspection.
- Two types of parameter *arguments* will also autocomplete:
  - --sampler_name  will autocomplete the scheduler name
  - --model will autocomplete the model name
- There don't seem to be commands for reading/writing image files yet, so
  path autocompletion is not implemented
2023-03-26 00:24:27 -04:00
dac3c158a5 Merge branch 'main' into feat/preview_predicted_x0
- resolve conflicts with generate.py invocation
- remove unused symbols that pyflakes complains about
- add **untested** code for passing intermediate latent image to the
  step callback in the format expected.
2023-03-25 16:07:18 -04:00
288cee9611 Merge remote-tracking branch 'origin/main' into feat/preview_predicted_x0
# Conflicts:
#	invokeai/app/invocations/generate.py
2023-03-12 20:56:02 -07:00
06aa5a8120 Merge branch 'main' into feat/preview_predicted_x0 2023-03-11 14:50:30 -06:00
f45483e519 Merge branch 'main' into feat/preview_predicted_x0 2023-03-10 22:25:26 -06:00
63f59201f8 Merge branch 'main' into feat/preview_predicted_x0 2023-03-10 12:34:07 -06:00
4a00f1cc74 Merge branch 'main' into feat/preview_predicted_x0 2023-03-10 09:20:01 -06:00
fe6858f2d9 feat: use the predicted denoised image for previews
Some schedulers report not only the noisy latents at the current timestep,
but also their estimate so far of what the de-noised latents will be.

It makes for a more legible preview than the noisy latents do.
2023-03-09 20:28:06 -08:00
23 changed files with 548 additions and 164 deletions

View File

@@ -145,7 +145,7 @@ not supported.
 _For Linux with an AMD GPU:_
 ```sh
-pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.2
+pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
 ```
 _For Macintoshes, either Intel or M1/M2:_

View File

@@ -417,7 +417,7 @@ Then type the following commands:
 === "AMD System"
 ```bash
-pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/rocm5.2
+pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
 ```
 ### Corrupted configuration file

View File

@@ -154,7 +154,7 @@ manager, please follow these steps:
 === "ROCm (AMD)"
 ```bash
-pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.2
+pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
 ```
 === "CPU (Intel Macs & non-GPU systems)"
@@ -315,7 +315,7 @@ installation protocol (important!)
 === "ROCm (AMD)"
 ```bash
-pip install -e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.2
+pip install -e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
 ```
 === "CPU (Intel Macs & non-GPU systems)"

View File

@@ -110,7 +110,7 @@ recipes are available
 When installing torch and torchvision manually with `pip`, remember to provide
 the argument `--extra-index-url
-https://download.pytorch.org/whl/rocm5.2` as described in the [Manual
+https://download.pytorch.org/whl/rocm5.4.2` as described in the [Manual
 Installation Guide](020_INSTALL_MANUAL.md).
 This will be done automatically for you if you use the installer

View File

@@ -270,3 +270,18 @@ async def invoke_session(
     ApiDependencies.invoker.invoke(session, invoke_all=all)
     return Response(status_code=202)
+
+
+@session_router.delete(
+    "/{session_id}/invoke",
+    operation_id="cancel_session_invoke",
+    responses={
+        202: {"description": "The invocation is canceled"}
+    },
+)
+async def cancel_session_invoke(
+    session_id: str = Path(description="The id of the session to cancel"),
+) -> None:
+    """Cancels an in-flight session invocation"""
+    ApiDependencies.invoker.cancel(session_id)
+    return Response(status_code=202)
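
For reference, a hypothetical client call against this route. The host,
port, and `/api/v1/sessions` prefix are assumptions about a default local
deployment, not something shown in this diff:

```python
import requests  # assumes `pip install requests`

BASE = "http://localhost:9090/api/v1/sessions"   # assumed deployment URL
session_id = "some-session-id"                   # hypothetical id

# DELETE /{session_id}/invoke cancels the running invocation
resp = requests.delete(f"{BASE}/{session_id}/invoke")
print(resp.status_code)  # 202: "The invocation is canceled"
```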

View File

@@ -0,0 +1,167 @@
```python
"""
Readline helper functions for cli_app.py
You may import the global singleton `completer` to get access to the
completer object.
"""
import atexit
import readline
import shlex

from pathlib import Path
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin

from ...backend import ModelManager, Globals
from ..invocations.baseinvocation import BaseInvocation
from .commands import BaseCommand

# singleton object, class variable
completer = None


class Completer(object):
    def __init__(self, model_manager: ModelManager):
        self.commands = self.get_commands()
        self.matches = None
        self.linebuffer = None
        self.manager = model_manager
        return

    def complete(self, text, state):
        """
        Complete commands and switches from the node CLI command line.
        Switches are determined in a context-specific manner.
        """
        buffer = readline.get_line_buffer()
        if state == 0:
            options = None
            try:
                current_command, current_switch = self.get_current_command(buffer)
                options = self.get_command_options(current_command, current_switch)
            except IndexError:
                pass
            options = options or list(self.parse_commands().keys())

            if not text:  # first time
                self.matches = options
            else:
                self.matches = [s for s in options if s and s.startswith(text)]

        try:
            match = self.matches[state]
        except IndexError:
            match = None
        return match

    @classmethod
    def get_commands(cls) -> List[object]:
        """
        Return a list of all the client commands and invocations.
        """
        return BaseCommand.get_commands() + BaseInvocation.get_invocations()

    def get_current_command(self, buffer: str) -> tuple[str, str]:
        """
        Parse the readline buffer to find the most recent command and its switch.
        """
        if len(buffer) == 0:
            return None, None
        tokens = shlex.split(buffer)
        command = None
        switch = None
        for t in tokens:
            if t[0].isalpha():
                if switch is None:
                    command = t
            else:
                switch = t
        # don't try to autocomplete switches that are already complete
        if switch and buffer.endswith(' '):
            switch = None
        return command or '', switch or ''

    def parse_commands(self) -> Dict[str, List[str]]:
        """
        Return a dict in which the keys are the command name
        and the values are the parameters the command takes.
        """
        result = dict()
        for command in self.commands:
            hints = get_type_hints(command)
            name = get_args(hints['type'])[0]
            result.update({name: hints})
        return result

    def get_command_options(self, command: str, switch: str) -> List[str]:
        """
        Return all the parameters that can be passed to the command as
        command-line switches. Returns None if the command is unrecognized.
        """
        parsed_commands = self.parse_commands()
        if command not in parsed_commands:
            return None

        # handle switches in the format "-foo=bar"
        argument = None
        if switch and '=' in switch:
            switch, argument = switch.split('=')

        parameter = switch.strip('-')
        if parameter in parsed_commands[command]:
            if argument is None:
                return self.get_parameter_options(parameter, parsed_commands[command][parameter])
            else:
                return [f"--{parameter}={x}" for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])]
        else:
            return [f"--{x}" for x in parsed_commands[command].keys()]

    def get_parameter_options(self, parameter: str, typehint) -> List[str]:
        """
        Given a parameter type (such as Literal), offers autocompletions.
        """
        if get_origin(typehint) == Literal:
            return get_args(typehint)
        if parameter == 'model':
            return self.manager.model_names()

    def _pre_input_hook(self):
        if self.linebuffer:
            readline.insert_text(self.linebuffer)
            readline.redisplay()
            self.linebuffer = None


def set_autocompleter(model_manager: ModelManager) -> Completer:
    global completer

    if completer:
        return completer

    completer = Completer(model_manager)

    readline.set_completer(completer.complete)
    # pyreadline3 does not have a set_auto_history() method
    try:
        readline.set_auto_history(True)
    except:
        pass
    readline.set_pre_input_hook(completer._pre_input_hook)
    readline.set_completer_delims(" ")
    readline.parse_and_bind("tab: complete")
    readline.parse_and_bind("set print-completions-horizontally off")
    readline.parse_and_bind("set page-completions on")
    readline.parse_and_bind("set skip-completed-text on")
    readline.parse_and_bind("set show-all-if-ambiguous on")

    histfile = Path(Globals.root, ".invoke_history")
    try:
        readline.read_history_file(histfile)
        readline.set_history_length(1000)
    except FileNotFoundError:
        pass
    except OSError:  # file likely corrupted
        newname = f"{histfile}.old"
        print(
            f"## Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
        )
        histfile.replace(Path(newname))
    atexit.register(readline.write_history_file, histfile)
```

View File

@@ -14,6 +14,7 @@ from pydantic.fields import Field
 from ..backend import Args
 from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
+from .cli.completer import set_autocompleter
 from .invocations import *
 from .invocations.baseinvocation import BaseInvocation
 from .services.events import EventServiceBase
@@ -130,6 +131,12 @@ def invoke_cli():
     config.parse_args()
     model_manager = get_model_manager(config)

+    # This initializes the autocompleter and returns it.
+    # Currently nothing is done with the returned Completer
+    # object, but the object can be used to change autocompletion
+    # behavior on the fly, if desired.
+    completer = set_autocompleter(model_manager)
+
     events = EventServiceBase()
     output_folder = os.path.abspath(
@@ -162,8 +169,8 @@ def invoke_cli():
     while True:
         try:
-            cmd_input = input("> ")
-        except KeyboardInterrupt:
+            cmd_input = input("invoke> ")
+        except (KeyboardInterrupt, EOFError):
             # Ctrl-c exits
             break

View File

@@ -1,22 +1,19 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

-from datetime import datetime, timezone
-from typing import Any, Literal, Optional, Union
+from functools import partial
+from typing import Literal, Optional, Union

 import numpy as np
 from torch import Tensor
-from PIL import Image
 from pydantic import Field
-from skimage.exposure.histogram_matching import match_histograms

 from ..services.image_storage import ImageType
-from ..services.invocation_services import InvocationServices
 from .baseinvocation import BaseInvocation, InvocationContext
 from .image import ImageField, ImageOutput
-from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
+from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
 from ...backend.stable_diffusion import PipelineIntermediateState
-from ...backend.util.util import image_to_dataURL
+from ..util.util import diffusers_step_callback_adapter, CanceledException

 SAMPLER_NAME_VALUES = Literal[
     tuple(InvokeAIGenerator.schedulers())
@@ -45,32 +42,26 @@ class TextToImageInvocation(BaseInvocation):
     # TODO: pass this an emitter method or something? or a session for dispatching?
     def dispatch_progress(
-        self, context: InvocationContext, sample: Tensor, step: int
+        self, context: InvocationContext, intermediate_state: PipelineIntermediateState
     ) -> None:
-        # TODO: only output a preview image when requested
-        image = Generator.sample_to_lowres_estimated_image(sample)
-        (width, height) = image.size
-        width *= 8
-        height *= 8
-        dataURL = image_to_dataURL(image, image_format="JPEG")
-        context.services.events.emit_generator_progress(
-            context.graph_execution_state_id,
-            self.id,
-            {
-                "width": width,
-                "height": height,
-                "dataURL": dataURL
-            },
-            step,
-            self.steps,
-        )
+        if (context.services.queue.is_canceled(context.graph_execution_state_id)):
+            raise CanceledException
+        step = intermediate_state.step
+        if intermediate_state.predicted_original is not None:
+            # Some schedulers report not only the noisy latents at the current timestep,
+            # but also their estimate so far of what the de-noised latents will be.
+            sample = intermediate_state.predicted_original
+        else:
+            sample = intermediate_state.latents
+        diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)

     def invoke(self, context: InvocationContext) -> ImageOutput:
-        def step_callback(state: PipelineIntermediateState):
-            self.dispatch_progress(context, state.latents, state.step)
+        # def step_callback(state: PipelineIntermediateState):
+        #     if (context.services.queue.is_canceled(context.graph_execution_state_id)):
+        #         raise CanceledException
+        #     self.dispatch_progress(context, state.latents, state.step)

         # Handle invalid model parameter
         # TODO: figure out if this can be done via a validator that uses the model_cache
@@ -79,7 +70,7 @@ class TextToImageInvocation(BaseInvocation):
         model= context.services.model_manager.get_model()
         outputs = Txt2Img(model).generate(
             prompt=self.prompt,
-            step_callback=step_callback,
+            step_callback=partial(self.dispatch_progress, context),
             **self.dict(
                 exclude={"prompt"}
             ),  # Shorthand for passing all of the parameters above manually
@@ -116,6 +107,22 @@ class ImageToImageInvocation(TextToImageInvocation):
         description="Whether or not the result should be fit to the aspect ratio of the input image",
     )

+    def dispatch_progress(
+        self, context: InvocationContext, intermediate_state: PipelineIntermediateState
+    ) -> None:
+        if (context.services.queue.is_canceled(context.graph_execution_state_id)):
+            raise CanceledException
+        step = intermediate_state.step
+        if intermediate_state.predicted_original is not None:
+            # Some schedulers report not only the noisy latents at the current timestep,
+            # but also their estimate so far of what the de-noised latents will be.
+            sample = intermediate_state.predicted_original
+        else:
+            sample = intermediate_state.latents
+        diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
+
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = (
             None
@@ -126,24 +133,23 @@ class ImageToImageInvocation(TextToImageInvocation):
         )
         mask = None

-        def step_callback(sample, step=0):
-            self.dispatch_progress(context, sample, step)
-
         # Handle invalid model parameter
         # TODO: figure out if this can be done via a validator that uses the model_cache
         # TODO: How to get the default model name now?
         model = context.services.model_manager.get_model()
-        generator_output = next(
-            Img2Img(model).generate(
-                prompt=self.prompt,
-                init_image=image,
-                init_mask=mask,
-                step_callback=step_callback,
-                **self.dict(
-                    exclude={"prompt", "image", "mask"}
-                ),  # Shorthand for passing all of the parameters above manually
-            )
-        )
+        outputs = Img2Img(model).generate(
+            prompt=self.prompt,
+            init_image=image,
+            init_mask=mask,
+            step_callback=partial(self.dispatch_progress, context),
+            **self.dict(
+                exclude={"prompt", "image", "mask"}
+            ),  # Shorthand for passing all of the parameters above manually
+        )
+
+        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
+        # each time it is called. We only need the first one.
+        generator_output = next(outputs)

         result_image = generator_output.image
@@ -173,6 +179,22 @@ class InpaintInvocation(ImageToImageInvocation):
         description="The amount by which to replace masked areas with latent noise",
     )

+    def dispatch_progress(
+        self, context: InvocationContext, intermediate_state: PipelineIntermediateState
+    ) -> None:
+        if (context.services.queue.is_canceled(context.graph_execution_state_id)):
+            raise CanceledException
+        step = intermediate_state.step
+        if intermediate_state.predicted_original is not None:
+            # Some schedulers report not only the noisy latents at the current timestep,
+            # but also their estimate so far of what the de-noised latents will be.
+            sample = intermediate_state.predicted_original
+        else:
+            sample = intermediate_state.latents
+        diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
+
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = (
             None
@@ -187,24 +209,23 @@ class InpaintInvocation(ImageToImageInvocation):
             else context.services.images.get(self.mask.image_type, self.mask.image_name)
         )

-        def step_callback(sample, step=0):
-            self.dispatch_progress(context, sample, step)
-
         # Handle invalid model parameter
         # TODO: figure out if this can be done via a validator that uses the model_cache
         # TODO: How to get the default model name now?
-        manager = context.services.model_manager.get_model()
-        generator_output = next(
-            Inpaint(model).generate(
-                prompt=self.prompt,
-                init_image=image,
-                mask_image=mask,
-                step_callback=step_callback,
-                **self.dict(
-                    exclude={"prompt", "image", "mask"}
-                ),  # Shorthand for passing all of the parameters above manually
-            )
-        )
+        model = context.services.model_manager.get_model()
+        outputs = Inpaint(model).generate(
+            prompt=self.prompt,
+            init_img=image,
+            init_mask=mask,
+            step_callback=partial(self.dispatch_progress, context),
+            **self.dict(
+                exclude={"prompt", "image", "mask"}
+            ),  # Shorthand for passing all of the parameters above manually
+        )
+
+        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
+        # each time it is called. We only need the first one.
+        generator_output = next(outputs)

         result_image = generator_output.image

View File

@@ -28,12 +28,28 @@ class ImageOutput(BaseInvocationOutput):
     image: ImageField = Field(default=None, description="The output image")
     #fmt: on

+    class Config:
+        schema_extra = {
+            'required': [
+                'type',
+                'image',
+            ]
+        }
+

 class MaskOutput(BaseInvocationOutput):
     """Base class for invocations that output a mask"""
     #fmt: off
     type: Literal["mask"] = "mask"
     mask: ImageField = Field(default=None, description="The output mask")
-    #fomt: on
+    #fmt: on
+
+    class Config:
+        schema_extra = {
+            'required': [
+                'type',
+                'mask',
+            ]
+        }

 # TODO: this isn't really necessary anymore
 class LoadImageInvocation(BaseInvocation):
View File

@@ -12,3 +12,11 @@ class PromptOutput(BaseInvocationOutput):
     prompt: str = Field(default=None, description="The output prompt")
     #fmt: on
+
+    class Config:
+        schema_extra = {
+            'required': [
+                'type',
+                'prompt',
+            ]
+        }

View File

@@ -127,6 +127,13 @@ class NodeAlreadyExecutedError(Exception):
 class GraphInvocationOutput(BaseInvocationOutput):
     type: Literal["graph_output"] = "graph_output"

+    class Config:
+        schema_extra = {
+            'required': [
+                'type',
+                'image',
+            ]
+        }

 # TODO: Fill this out and move to invocations
 class GraphInvocation(BaseInvocation):
@@ -147,6 +154,13 @@ class IterateInvocationOutput(BaseInvocationOutput):
     item: Any = Field(description="The item being iterated over")

+    class Config:
+        schema_extra = {
+            'required': [
+                'type',
+                'item',
+            ]
+        }

 # TODO: Fill this out and move to invocations
 class IterateInvocation(BaseInvocation):
@@ -169,6 +183,13 @@ class CollectInvocationOutput(BaseInvocationOutput):
     collection: list[Any] = Field(description="The collection of input items")

+    class Config:
+        schema_extra = {
+            'required': [
+                'type',
+                'collection',
+            ]
+        }

 class CollectInvocation(BaseInvocation):
     """Collects values into a collection"""

View File

@@ -2,6 +2,7 @@
 from abc import ABC, abstractmethod
 from queue import Queue
+import time

 # TODO: make this serializable
@@ -10,6 +11,7 @@ class InvocationQueueItem:
     graph_execution_state_id: str
     invocation_id: str
     invoke_all: bool
+    timestamp: float

     def __init__(
         self,
@@ -22,6 +24,7 @@ class InvocationQueueItem:
         self.graph_execution_state_id = graph_execution_state_id
         self.invocation_id = invocation_id
         self.invoke_all = invoke_all
+        self.timestamp = time.time()

 class InvocationQueueABC(ABC):
@@ -35,15 +38,44 @@ class InvocationQueueABC(ABC):
     def put(self, item: InvocationQueueItem | None) -> None:
         pass

+    @abstractmethod
+    def cancel(self, graph_execution_state_id: str) -> None:
+        pass
+
+    @abstractmethod
+    def is_canceled(self, graph_execution_state_id: str) -> bool:
+        pass
+

 class MemoryInvocationQueue(InvocationQueueABC):
     __queue: Queue
+    __cancellations: dict[str, float]

     def __init__(self):
         self.__queue = Queue()
+        self.__cancellations = dict()

     def get(self) -> InvocationQueueItem:
-        return self.__queue.get()
+        item = self.__queue.get()
+
+        while isinstance(item, InvocationQueueItem) \
+            and item.graph_execution_state_id in self.__cancellations \
+            and self.__cancellations[item.graph_execution_state_id] > item.timestamp:
+            item = self.__queue.get()
+
+        # Clear old items
+        for graph_execution_state_id in list(self.__cancellations.keys()):
+            if self.__cancellations[graph_execution_state_id] < item.timestamp:
+                del self.__cancellations[graph_execution_state_id]
+
+        return item

     def put(self, item: InvocationQueueItem | None) -> None:
         self.__queue.put(item)
+
+    def cancel(self, graph_execution_state_id: str) -> None:
+        if graph_execution_state_id not in self.__cancellations:
+            self.__cancellations[graph_execution_state_id] = time.time()
+
+    def is_canceled(self, graph_execution_state_id: str) -> bool:
+        return graph_execution_state_id in self.__cancellations
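
A quick behavioral sketch of the new API; the import path is an assumption
based on the repository layout, not something stated in this diff:

```python
# assumed module path for the file shown above
from invokeai.app.services.invocation_queue import (
    InvocationQueueItem,
    MemoryInvocationQueue,
)

queue = MemoryInvocationQueue()
queue.put(InvocationQueueItem(
    graph_execution_state_id="graph-1",
    invocation_id="node-1",
    invoke_all=False,
))

# cancel() stamps the graph id with time.time(); get() then discards any
# queued item whose timestamp predates the stamp instead of returning it.
queue.cancel("graph-1")
print(queue.is_canceled("graph-1"))  # True
```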

View File

@@ -51,6 +51,10 @@ class Invoker:
         self.services.graph_execution_manager.set(new_state)
         return new_state

+    def cancel(self, graph_execution_state_id: str) -> None:
+        """Cancels the given execution state"""
+        self.services.queue.cancel(graph_execution_state_id)
+
     def __start_service(self, service) -> None:
         # Call start() method on any services that have it
         start_op = getattr(service, "start", None)

View File

@@ -4,7 +4,7 @@ from threading import Event, Thread
 from ..invocations.baseinvocation import InvocationContext
 from .invocation_queue import InvocationQueueItem
 from .invoker import InvocationProcessorABC, Invoker
+from ..util.util import CanceledException

 class DefaultInvocationProcessor(InvocationProcessorABC):
     __invoker_thread: Thread
@@ -58,6 +58,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                     )
                 )

+                # Check queue to see if this is canceled, and skip if so
+                if self.__invoker.services.queue.is_canceled(
+                    graph_execution_state.id
+                ):
+                    continue
+
                 # Save outputs and history
                 graph_execution_state.complete(invocation.id, outputs)
@@ -76,6 +82,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
             except KeyboardInterrupt:
                 pass

+            except CanceledException:
+                pass
+
             except Exception as e:
                 error = traceback.format_exc()
@@ -96,6 +105,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                     pass

+                # Check queue to see if this is canceled, and skip if so
+                if self.__invoker.services.queue.is_canceled(
+                    graph_execution_state.id
+                ):
+                    continue
+
                 # Queue any further commands if invoking all
                 is_complete = graph_execution_state.is_complete()
                 if queue_item.invoke_all and not is_complete:

invokeai/app/util/util.py (new file, 42 lines)
View File

@@ -0,0 +1,42 @@
```python
import torch
from PIL import Image

from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
from ...backend.generator.base import Generator
from ...backend.stable_diffusion import PipelineIntermediateState


class CanceledException(Exception):
    pass


def fast_latents_step_callback(sample: torch.Tensor, step: int, steps: int, id: str, context: InvocationContext):
    # TODO: only output a preview image when requested
    image = Generator.sample_to_lowres_estimated_image(sample)

    (width, height) = image.size
    width *= 8
    height *= 8

    dataURL = image_to_dataURL(image, image_format="JPEG")

    context.services.events.emit_generator_progress(
        context.graph_execution_state_id,
        id,
        {
            "width": width,
            "height": height,
            "dataURL": dataURL
        },
        step,
        steps,
    )


def diffusers_step_callback_adapter(*cb_args, **kwargs):
    """
    txt2img gives us a Tensor in the step_callback, while img2img gives us a PipelineIntermediateState.
    This adapter grabs the needed data and passes it along to the callback function.
    """
    if isinstance(cb_args[0], PipelineIntermediateState):
        progress_state: PipelineIntermediateState = cb_args[0]
        return fast_latents_step_callback(progress_state.latents, progress_state.step, **kwargs)
    else:
        return fast_latents_step_callback(*cb_args, **kwargs)
```

View File

@@ -21,7 +21,7 @@ from PIL import Image, ImageChops, ImageFilter
 from accelerate.utils import set_seed
 from diffusers import DiffusionPipeline
 from tqdm import trange
-from typing import List, Iterator, Type
+from typing import Callable, List, Iterator, Optional, Type
 from dataclasses import dataclass, field

 from diffusers.schedulers import SchedulerMixin as Scheduler
@@ -35,23 +35,23 @@ downsampling = 8
 @dataclass
 class InvokeAIGeneratorBasicParams:
-    seed: int=None
+    seed: Optional[int]=None
     width: int=512
     height: int=512
-    cfg_scale: int=7.5
+    cfg_scale: float=7.5
     steps: int=20
     ddim_eta: float=0.0
-    scheduler: int='ddim'
+    scheduler: str='ddim'
     precision: str='float16'
     perlin: float=0.0
-    threshold: int=0.0
+    threshold: float=0.0
     seamless: bool=False
     seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y'])
-    h_symmetry_time_pct: float=None
-    v_symmetry_time_pct: float=None
+    h_symmetry_time_pct: Optional[float]=None
+    v_symmetry_time_pct: Optional[float]=None
     variation_amount: float = 0.0
     with_variations: list=field(default_factory=list)
-    safety_checker: SafetyChecker=None
+    safety_checker: Optional[SafetyChecker]=None

 @dataclass
 class InvokeAIGeneratorOutput:
@@ -61,10 +61,10 @@ class InvokeAIGeneratorOutput:
     and the model hash, as well as all the generate() parameters that went into
     generating the image (in .params, also available as attributes)
     '''
-    image: Image
+    image: Image.Image
     seed: int
     model_hash: str
-    attention_maps_images: List[Image]
+    attention_maps_images: List[Image.Image]
     params: Namespace

 # we are interposing a wrapper around the original Generator classes so that
@@ -92,8 +92,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
     def generate(self,
                  prompt: str='',
-                 callback: callable=None,
-                 step_callback: callable=None,
+                 callback: Optional[Callable]=None,
+                 step_callback: Optional[Callable]=None,
                  iterations: int=1,
                  **keyword_args,
                  )->Iterator[InvokeAIGeneratorOutput]:
@@ -206,10 +206,10 @@ class Txt2Img(InvokeAIGenerator):

 # ------------------------------------
 class Img2Img(InvokeAIGenerator):
     def generate(self,
-                 init_image: Image | torch.FloatTensor,
+                 init_image: Image.Image | torch.FloatTensor,
                  strength: float=0.75,
                  **keyword_args
-                 )->List[InvokeAIGeneratorOutput]:
+                 )->Iterator[InvokeAIGeneratorOutput]:
         return super().generate(init_image=init_image,
                                 strength=strength,
                                 **keyword_args
@@ -223,7 +223,7 @@ class Img2Img(InvokeAIGenerator):
 # Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
 class Inpaint(Img2Img):
     def generate(self,
-                 mask_image: Image | torch.FloatTensor,
+                 mask_image: Image.Image | torch.FloatTensor,
                  # Seam settings - when 0, doesn't fill seam
                  seam_size: int = 0,
                  seam_blur: int = 0,
@@ -236,7 +236,7 @@ class Inpaint(Img2Img):
                  inpaint_height=None,
                  inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
                  **keyword_args
-                 )->List[InvokeAIGeneratorOutput]:
+                 )->Iterator[InvokeAIGeneratorOutput]:
         return super().generate(
             mask_image=mask_image,
             seam_size=seam_size,
@@ -263,7 +263,7 @@ class Embiggen(Txt2Img):
                  embiggen: list=None,
                  embiggen_tiles: list = None,
                  strength: float=0.75,
-                 **kwargs)->List[InvokeAIGeneratorOutput]:
+                 **kwargs)->Iterator[InvokeAIGeneratorOutput]:
         return super().generate(embiggen=embiggen,
                                 embiggen_tiles=embiggen_tiles,
                                 strength=strength,

View File

@@ -1264,10 +1264,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
         cache_dir=cache_dir,
     )
     pipe = pipeline_class(
-        vae=vae,
-        text_encoder=text_model,
+        vae=vae.to(precision),
+        text_encoder=text_model.to(precision),
         tokenizer=tokenizer,
-        unet=unet,
+        unet=unet.to(precision),
         scheduler=scheduler,
         safety_checker=None,
         feature_extractor=None,

View File

@@ -18,7 +18,7 @@ import warnings
 from enum import Enum
 from pathlib import Path
 from shutil import move, rmtree
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, Callable

 import safetensors
 import safetensors.torch
@@ -630,14 +630,13 @@ class ModelManager(object):
     def heuristic_import(
         self,
         path_url_or_repo: str,
-        convert: bool = True,
         model_name: str = None,
         description: str = None,
         model_config_file: Path = None,
         commit_to_conf: Path = None,
+        config_file_callback: Callable[[Path], Path] = None,
     ) -> str:
-        """
-        Accept a string which could be:
+        """Accept a string which could be:
         - a HF diffusers repo_id
         - a URL pointing to a legacy .ckpt or .safetensors file
         - a local path pointing to a legacy .ckpt or .safetensors file
@@ -651,16 +650,20 @@ class ModelManager(object):
         The model_name and/or description can be provided. If not, they will
         be generated automatically.

-        If convert is true, legacy models will be converted to diffusers
-        before importing.
-
         If commit_to_conf is provided, the newly loaded model will be written
         to the `models.yaml` file at the indicated path. Otherwise, the changes
         will only remain in memory.

-        The (potentially derived) name of the model is returned on success, or None
-        on failure. When multiple models are added from a directory, only the last
-        imported one is returned.
+        The routine will do its best to figure out the config file
+        needed to convert a legacy checkpoint file, but if it can't it
+        will call the config_file_callback routine, if provided. The
+        callback accepts a single argument, the Path to the checkpoint
+        file, and returns a Path to the config file to use.
+
+        The (potentially derived) name of the model is returned on
+        success, or None on failure. When multiple models are added
+        from a directory, only the last imported one is returned.
         """
         model_path: Path = None
         thing = path_url_or_repo  # to save typing
@@ -707,7 +710,7 @@ class ModelManager(object):
                 Path(thing).rglob("*.safetensors")
             ):
                 if model_name := self.heuristic_import(
-                    str(m), convert, commit_to_conf=commit_to_conf
+                    str(m), commit_to_conf=commit_to_conf
                 ):
                     print(f" >> {model_name} successfully imported")
             return model_name
@@ -735,7 +738,7 @@ class ModelManager(object):
         # another round of heuristics to guess the correct config file.
         checkpoint = None
-        if model_path.suffix.endswith((".ckpt",".pt")):
+        if model_path.suffix in [".ckpt",".pt"]:
             self.scan_model(model_path,model_path)
             checkpoint = torch.load(model_path)
         else:
@@ -743,6 +746,12 @@ class ModelManager(object):
         # additional probing needed if no config file provided
         if model_config_file is None:
+            # look for a like-named .yaml file in same directory
+            if model_path.with_suffix(".yaml").exists():
+                model_config_file = model_path.with_suffix(".yaml")
+                print(f" | Using config file {model_config_file.name}")
+            else:
                 model_type = self.probe_model_type(checkpoint)
                 if model_type == SDLegacyType.V1:
                     print(" | SD-v1 model detected")
@@ -756,20 +765,18 @@ class ModelManager(object):
                     )
                 elif model_type == SDLegacyType.V2_v:
                     print(
-                        " | SD-v2-v model detected; model will be converted to diffusers format"
+                        " | SD-v2-v model detected"
                     )
                     model_config_file = Path(
                         Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
                     )
-                    convert = True
                 elif model_type == SDLegacyType.V2_e:
                     print(
-                        " | SD-v2-e model detected; model will be converted to diffusers format"
+                        " | SD-v2-e model detected"
                     )
                     model_config_file = Path(
                         Globals.root, "configs/stable-diffusion/v2-inference.yaml"
                     )
-                    convert = True
                 elif model_type == SDLegacyType.V2:
                     print(
                         f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
@@ -781,13 +788,29 @@ class ModelManager(object):
                     )
                     return

+        if not model_config_file and config_file_callback:
+            model_config_file = config_file_callback(model_path)
+
+        # despite our best efforts, we could not find a model config file, so give up
+        if not model_config_file:
+            return
+
+        # look for a custom vae, a like-named file ending with .vae in the same directory
+        vae_path = None
+        for suffix in ["pt", "ckpt", "safetensors"]:
+            if (model_path.with_suffix(f".vae.{suffix}")).exists():
+                vae_path = model_path.with_suffix(f".vae.{suffix}")
+                print(f" | Using VAE file {vae_path.name}")
+        vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
+
         diffuser_path = Path(
             Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
         )
         model_name = self.convert_and_import(
             model_path,
             diffusers_path=diffuser_path,
-            vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
+            vae=vae,
+            vae_path=str(vae_path),
             model_name=model_name,
             model_description=description,
             original_config_file=model_config_file,
@@ -829,8 +852,8 @@ class ModelManager(object):
             return

         model_name = model_name or diffusers_path.name
-        model_description = model_description or f"Optimized version of {model_name}"
-        print(f">> Optimizing {model_name} (30-60s)")
+        model_description = model_description or f"Converted version of {model_name}"
+        print(f" | Converting {model_name} to diffusers (30-60s)")
         try:
             # By passing the specified VAE to the conversion function, the autoencoder
             # will be built into the model rather than tacked on afterward via the config file
@@ -848,7 +871,7 @@ class ModelManager(object):
                 scan_needed=scan_needed,
             )
             print(
-                f" | Success. Optimized model is now located at {str(diffusers_path)}"
+                f" | Success. Converted model is now located at {str(diffusers_path)}"
             )
             print(f" | Writing new config file entry for {model_name}")
             new_config = dict(

View File

@@ -1022,7 +1022,7 @@ class InvokeAIWebServer:
                     "RGB"
                 )

-            def image_progress(sample, step):
+            def image_progress(intermediate_state: PipelineIntermediateState):
                 if self.canceled.is_set():
                     raise CanceledException
@@ -1030,6 +1030,14 @@ class InvokeAIWebServer:
                 nonlocal generation_parameters
                 nonlocal progress

+                step = intermediate_state.step
+                if intermediate_state.predicted_original is not None:
+                    # Some schedulers report not only the noisy latents at the current timestep,
+                    # but also their estimate so far of what the de-noised latents will be.
+                    sample = intermediate_state.predicted_original
+                else:
+                    sample = intermediate_state.latents
+
                 generation_messages = {
                     "txt2img": "common.statusGeneratingTextToImage",
                     "img2img": "common.statusGeneratingImageToImage",
@@ -1302,16 +1310,9 @@ class InvokeAIWebServer:
                 progress.set_current_iteration(progress.current_iteration + 1)

-            def diffusers_step_callback_adapter(*cb_args, **kwargs):
-                if isinstance(cb_args[0], PipelineIntermediateState):
-                    progress_state: PipelineIntermediateState = cb_args[0]
-                    return image_progress(progress_state.latents, progress_state.step)
-                else:
-                    return image_progress(*cb_args, **kwargs)
-
             self.generate.prompt2image(
                 **generation_parameters,
-                step_callback=diffusers_step_callback_adapter,
+                step_callback=image_progress,
                 image_callback=image_done,
             )

View File

@@ -626,7 +626,7 @@ def set_default_output_dir(opt: Args, completer: Completer):
     completer.set_default_dir(opt.outdir)

-def import_model(model_path: str, gen, opt, completer, convert=False):
+def import_model(model_path: str, gen, opt, completer):
     """
     model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path;
     (3) a huggingface repository id; or (4) a local directory containing a
@@ -657,7 +657,6 @@ def import_model(model_path: str, gen, opt, completer):
         model_path,
         model_name=model_name,
         description=model_desc,
-        convert=convert,
     )

     if not imported_name:
@@ -666,7 +665,6 @@ def import_model(model_path: str, gen, opt, completer):
             model_path,
             model_name=model_name,
             description=model_desc,
-            convert=convert,
             model_config_file=config_file,
         )
         if not imported_name:
@@ -757,7 +755,6 @@ def _get_model_name_and_desc(
     )
     return model_name, model_description

-
 def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
     model_name_or_path = model_name_or_path.replace("\\", "/")  # windows
     manager = gen.model_manager
@@ -788,7 +785,7 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
         )
     else:
         try:
-            import_model(model_name_or_path, gen, opt, completer, convert=True)
+            import_model(model_name_or_path, gen, opt, completer)
         except KeyboardInterrupt:
             return

View File

@@ -34,7 +34,7 @@ const ReactPanZoomButtons = ({
       <IAIIconButton
         icon={<BiZoomIn />}
         aria-label={t('accessibility.zoomIn')}
-        tooltip="Zoom In"
+        tooltip={t('accessibility.zoomIn')}
         onClick={() => zoomIn()}
         fontSize={20}
       />
@@ -42,7 +42,7 @@ const ReactPanZoomButtons = ({
       <IAIIconButton
         icon={<BiZoomOut />}
         aria-label={t('accessibility.zoomOut')}
-        tooltip="Zoom Out"
+        tooltip={t('accessibility.zoomOut')}
         onClick={() => zoomOut()}
         fontSize={20}
       />
@@ -50,7 +50,7 @@ const ReactPanZoomButtons = ({
      <IAIIconButton
         icon={<BiRotateLeft />}
         aria-label={t('accessibility.rotateCounterClockwise')}
-        tooltip="Rotate Counter-Clockwise"
+        tooltip={t('accessibility.rotateCounterClockwise')}
         onClick={rotateCounterClockwise}
         fontSize={20}
       />
@@ -58,7 +58,7 @@ const ReactPanZoomButtons = ({
       <IAIIconButton
         icon={<BiRotateRight />}
         aria-label={t('accessibility.rotateClockwise')}
-        tooltip="Rotate Clockwise"
+        tooltip={t('accessibility.rotateClockwise')}
         onClick={rotateClockwise}
         fontSize={20}
       />
@@ -66,7 +66,7 @@ const ReactPanZoomButtons = ({
       <IAIIconButton
         icon={<MdFlip />}
         aria-label={t('accessibility.flipHorizontally')}
-        tooltip="Flip Horizontally"
+        tooltip={t('accessibility.flipHorizontally')}
         onClick={flipHorizontally}
         fontSize={20}
       />
@@ -74,7 +74,7 @@ const ReactPanZoomButtons = ({
       <IAIIconButton
         icon={<MdFlip style={{ transform: 'rotate(90deg)' }} />}
         aria-label={t('accessibility.flipVertically')}
-        tooltip="Flip Vertically"
+        tooltip={t('accessibility.flipVertically')}
         onClick={flipVertically}
         fontSize={20}
       />
@@ -82,7 +82,7 @@ const ReactPanZoomButtons = ({
       <IAIIconButton
         icon={<BiReset />}
         aria-label={t('accessibility.reset')}
-        tooltip="Reset"
+        tooltip={t('accessibility.reset')}
         onClick={() => {
           resetTransform();
           reset();

View File

@@ -1,9 +1,24 @@
 import i18n from 'i18next';
 import LanguageDetector from 'i18next-browser-languagedetector';
 import Backend from 'i18next-http-backend';
 import { initReactI18next } from 'react-i18next';

-i18n
+import translationEN from '../dist/locales/en.json';
+
+if (import.meta.env.MODE === 'package') {
+  i18n.use(initReactI18next).init({
+    lng: 'en',
+    resources: {
+      en: { translation: translationEN },
+    },
+    debug: false,
+    interpolation: {
+      escapeValue: false,
+    },
+    returnNull: false,
+  });
+} else {
+  i18n
     .use(Backend)
     .use(LanguageDetector)
     .use(initReactI18next)
@@ -18,5 +33,6 @@ i18n
     },
     returnNull: false,
   });
+}

 export default i18n;

View File

@@ -45,7 +45,7 @@ dependencies = [
     "einops",
     "eventlet",
     "facexlib",
-    "fastapi==0.94.1",
+    "fastapi==0.88.0",
     "fastapi-events==0.8.0",
     "fastapi-socketio==0.0.10",
     "flask==2.1.3",
@@ -160,4 +160,3 @@ output = "coverage/index.xml"

 [flake8]
 max-line-length = 120