mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

commit 4b6d9a73ed
Merge branch 'main' into feat/batch-graphs

.github/workflows/style-checks.yml (vendored): 8 lines changed
@@ -1,6 +1,6 @@
name: style checks
-# just formatting for now
-# TODO: add isort and flake8 later
+# just formatting and flake8 for now
+# TODO: add isort later

on:
  pull_request:

@@ -20,8 +20,8 @@ jobs:

      - name: Install dependencies with pip
        run: |
-          pip install black
+          pip install black flake8 Flake8-pyproject

      # - run: isort --check-only .
      - run: black --check .
-     # - run: flake8
+     - run: flake8

@@ -8,3 +8,10 @@ repos:
        language: system
        entry: black
        types: [python]
+
+      - id: flake8
+        name: flake8
+        stages: [commit]
+        language: system
+        entry: flake8
+        types: [python]

@@ -407,7 +407,7 @@ def get_pip_from_venv(venv_path: Path) -> str:
    :rtype: str
    """

-    pip = "Scripts\pip.exe" if OS == "Windows" else "bin/pip"
+    pip = "Scripts\\pip.exe" if OS == "Windows" else "bin/pip"
    return str(venv_path.expanduser().resolve() / pip)

@@ -49,7 +49,7 @@ if __name__ == "__main__":

    try:
        inst.install(**args.__dict__)
-    except KeyboardInterrupt as exc:
+    except KeyboardInterrupt:
        print("\n")
        print("Ctrl-C pressed. Aborting.")
        print("Come back soon!")
@@ -70,7 +70,7 @@ def confirm_install(dest: Path) -> bool:
        )
    else:
        print(f"InvokeAI will be installed in {dest}")
-    dest_confirmed = not Confirm.ask(f"Would you like to pick a different location?", default=False)
+    dest_confirmed = not Confirm.ask("Would you like to pick a different location?", default=False)
    console.line()

    return dest_confirmed

@@ -90,7 +90,7 @@ def dest_path(dest=None) -> Path:
        dest = Path(dest).expanduser().resolve()
    else:
        dest = Path.cwd().expanduser().resolve()
-    prev_dest = dest.expanduser().resolve()
+    prev_dest = init_path = dest

    dest_confirmed = confirm_install(dest)

@@ -109,9 +109,9 @@ def dest_path(dest=None) -> Path:
        )

        console.line()
-        print(f"[orange3]Please select the destination directory for the installation:[/] \[{browse_start}]: ")
+        console.print(f"[orange3]Please select the destination directory for the installation:[/] \\[{browse_start}]: ")
        selected = prompt(
-            f">>> ",
+            ">>> ",
            complete_in_thread=True,
            completer=path_completer,
            default=str(browse_start) + os.sep,

@@ -134,14 +134,14 @@ def dest_path(dest=None) -> Path:
        try:
            dest.mkdir(exist_ok=True, parents=True)
            return dest
-        except PermissionError as exc:
-            print(
+        except PermissionError:
+            console.print(
                f"Failed to create directory {dest} due to insufficient permissions",
                style=Style(color="red"),
                highlight=True,
            )
-        except OSError as exc:
-            console.print_exception(exc)
+        except OSError:
+            console.print_exception()

        if Confirm.ask("Would you like to try again?"):
            dest_path(init_path)

@@ -1,6 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
-
from typing import Optional
from logging import Logger
import sqlite3
from invokeai.app.services.board_image_record_storage import (

@@ -48,7 +47,7 @@ def check_internet() -> bool:
    try:
        urllib.request.urlopen(host, timeout=1)
        return True
-    except:
+    except Exception:
        return False
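Note on the `except:` changes above (and the many like them below): a bare `except` catches `BaseException`, so it also swallows `KeyboardInterrupt` and `SystemExit`, while `except Exception:` lets those propagate. A standalone sketch, not InvokeAI code:

# Standalone sketch: why flake8 (E722) flags bare `except:` clauses.
# KeyboardInterrupt and SystemExit derive from BaseException, not Exception,
# so `except Exception:` lets them propagate and end the program cleanly.

def swallow_everything() -> bool:
    try:
        raise KeyboardInterrupt  # simulate Ctrl-C inside the guarded call
    except:  # noqa: E722  (bare except: even Ctrl-C is silently eaten)
        return False

def swallow_errors_only() -> bool:
    try:
        raise KeyboardInterrupt
    except Exception:  # KeyboardInterrupt is NOT an Exception, so it escapes
        return False

print(swallow_everything())  # False: the interrupt never surfaced
try:
    swallow_errors_only()
except KeyboardInterrupt:
    print("interrupt propagated")  # this branch runs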
@@ -34,7 +34,7 @@ async def add_image_to_board(
            board_id=board_id, image_name=image_name
        )
        return result
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=500, detail="Failed to add image to board")

@@ -53,7 +53,7 @@ async def remove_image_from_board(
    try:
        result = ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
        return result
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=500, detail="Failed to remove image from board")

@@ -79,10 +79,10 @@ async def add_images_to_board(
                    board_id=board_id, image_name=image_name
                )
                added_image_names.append(image_name)
-            except:
+            except Exception:
                pass
        return AddImagesToBoardResult(board_id=board_id, added_image_names=added_image_names)
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=500, detail="Failed to add images to board")

@@ -105,8 +105,8 @@ async def remove_images_from_board(
            try:
                ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
                removed_image_names.append(image_name)
-            except:
+            except Exception:
                pass
        return RemoveImagesFromBoardResult(removed_image_names=removed_image_names)
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=500, detail="Failed to remove images from board")
@@ -37,7 +37,7 @@ async def create_board(
    try:
        result = ApiDependencies.invoker.services.boards.create(board_name=board_name)
        return result
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=500, detail="Failed to create board")

@@ -50,7 +50,7 @@ async def get_board(
    try:
        result = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
        return result
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=404, detail="Board not found")

@@ -73,7 +73,7 @@ async def update_board(
    try:
        result = ApiDependencies.invoker.services.boards.update(board_id=board_id, changes=changes)
        return result
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=500, detail="Failed to update board")

@@ -105,7 +105,7 @@ async def delete_board(
            deleted_board_images=deleted_board_images,
            deleted_images=[],
        )
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=500, detail="Failed to delete board")
@@ -55,7 +55,7 @@ async def upload_image(
        if crop_visible:
            bbox = pil_image.getbbox()
            pil_image = pil_image.crop(bbox)
-    except:
+    except Exception:
        # Error opening the image
        raise HTTPException(status_code=415, detail="Failed to read image")

@@ -73,7 +73,7 @@ async def upload_image(
        response.headers["Location"] = image_dto.image_url

        return image_dto
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=500, detail="Failed to create image")

@@ -85,7 +85,7 @@ async def delete_image(

    try:
        ApiDependencies.invoker.services.images.delete(image_name)
-    except Exception as e:
+    except Exception:
        # TODO: Does this need any exception handling at all?
        pass

@@ -97,7 +97,7 @@ async def clear_intermediates() -> int:
    try:
        count_deleted = ApiDependencies.invoker.services.images.delete_intermediates()
        return count_deleted
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=500, detail="Failed to clear intermediates")
        pass

@@ -115,7 +115,7 @@ async def update_image(

    try:
        return ApiDependencies.invoker.services.images.update(image_name, image_changes)
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=400, detail="Failed to update image")

@@ -131,7 +131,7 @@ async def get_image_dto(

    try:
        return ApiDependencies.invoker.services.images.get_dto(image_name)
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=404)

@@ -147,7 +147,7 @@ async def get_image_metadata(

    try:
        return ApiDependencies.invoker.services.images.get_metadata(image_name)
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=404)

@@ -183,7 +183,7 @@ async def get_image_full(
        )
        response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
        return response
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=404)

@@ -212,7 +212,7 @@ async def get_image_thumbnail(
        response = FileResponse(path, media_type="image/webp", content_disposition_type="inline")
        response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
        return response
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=404)

@@ -234,7 +234,7 @@ async def get_image_urls(
            image_url=image_url,
            thumbnail_url=thumbnail_url,
        )
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=404)

@@ -282,10 +282,10 @@ async def delete_images_from_list(
            try:
                ApiDependencies.invoker.services.images.delete(image_name)
                deleted_images.append(image_name)
-            except:
+            except Exception:
                pass
        return DeleteImagesFromListResult(deleted_images=deleted_images)
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=500, detail="Failed to delete images")

@@ -303,10 +303,10 @@ async def star_images_in_list(
            try:
                ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=True))
                updated_image_names.append(image_name)
-            except:
+            except Exception:
                pass
        return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=500, detail="Failed to star images")

@@ -320,8 +320,8 @@ async def unstar_images_in_list(
            try:
                ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=False))
                updated_image_names.append(image_name)
-            except:
+            except Exception:
                pass
        return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=500, detail="Failed to unstar images")
@@ -8,7 +8,8 @@ from pydantic.fields import Field

from invokeai.app.services.batch_manager_storage import BatchProcess, BatchSession, BatchSessionNotFoundException

-from ...invocations import *
+# Importing * is bad karma but needed here for node detection
+from ...invocations import *  # noqa: F401 F403
from ...invocations.baseinvocation import BaseInvocation
from ...services.batch_manager import Batch, BatchProcessResponse
from ...services.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
@@ -1,6 +1,5 @@
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import asyncio
import sys
from inspect import signature
-
import logging

@@ -17,21 +16,11 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware
from pathlib import Path
from pydantic.schema import schema

# This should come early so that modules can log their initialization properly
from .services.config import InvokeAIAppConfig
from ..backend.util.logging import InvokeAILogger

-app_config = InvokeAIAppConfig.get_config()
-app_config.parse_args()
-logger = InvokeAILogger.getLogger(config=app_config)
from invokeai.version.invokeai_version import __version__

-# we call this early so that the message appears before
-# other invokeai initialization messages
-if app_config.version:
-    print(f"InvokeAI version {__version__}")
-    sys.exit(0)
-
import invokeai.frontend.web as web_dir
import mimetypes

@@ -40,12 +29,17 @@ from .api.routers import sessions, models, images, boards, board_images, app_inf
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase


import torch
-import invokeai.backend.util.hotfixes
+import invokeai.backend.util.hotfixes  # noqa: F401 (monkeypatching on import)

if torch.backends.mps.is_available():
-    import invokeai.backend.util.mps_fixes
+    import invokeai.backend.util.mps_fixes  # noqa: F401 (monkeypatching on import)


+app_config = InvokeAIAppConfig.get_config()
+app_config.parse_args()
+logger = InvokeAILogger.getLogger(config=app_config)


# fix for windows mimetypes registry entries being borked
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
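The `# noqa: F401 (monkeypatching on import)` comments above document imports that exist only for their side effects. A minimal standalone sketch of the pattern (the patch below is illustrative, not what InvokeAI's hotfix modules actually change):

# Sketch of "monkeypatch on import": the module is imported for its side
# effect, nothing from it is referenced afterwards, hence flake8's F401
# must be suppressed at the import site.
import math

_original_sqrt = math.sqrt

def _patched_sqrt(x: float) -> float:
    # Illustrative tweak only: clamp tiny negative inputs from rounding error.
    return _original_sqrt(max(x, 0.0))

math.sqrt = _patched_sqrt  # rebinding happens as an import-time side effect

print(math.sqrt(-1e-12))  # 0.0 instead of a math domain error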
@@ -230,13 +224,16 @@ def invoke_api():

    # replace uvicorn's loggers with InvokeAI's for consistent appearance
    for logname in ["uvicorn.access", "uvicorn"]:
-        l = logging.getLogger(logname)
-        l.handlers.clear()
+        log = logging.getLogger(logname)
+        log.handlers.clear()
        for ch in logger.handlers:
-            l.addHandler(ch)
+            log.addHandler(ch)

    loop.run_until_complete(server.serve())


if __name__ == "__main__":
-    invoke_api()
+    if app_config.version:
+        print(f"InvokeAI version {__version__}")
+    else:
+        invoke_api()
@@ -145,10 +145,10 @@ def set_autocompleter(services: InvocationServices) -> Completer:
    completer = Completer(services.model_manager)

    readline.set_completer(completer.complete)
-    # pyreadline3 does not have a set_auto_history() method
    try:
        readline.set_auto_history(True)
-    except:
+    except AttributeError:
+        # pyreadline3 does not have a set_auto_history() method
        pass
    readline.set_pre_input_hook(completer._pre_input_hook)
    readline.set_completer_delims(" ")
@@ -14,16 +14,8 @@ from pydantic.fields import Field

# This should come early so that the logger can pick up its configuration options
from .services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger

-config = InvokeAIAppConfig.get_config()
-config.parse_args()
-logger = InvokeAILogger().getLogger(config=config)
from invokeai.version.invokeai_version import __version__

-# we call this early so that the message appears before other invokeai initialization messages
-if config.version:
-    print(f"InvokeAI version {__version__}")
-    sys.exit(0)
-
from invokeai.app.services.board_image_record_storage import (
    SqliteBoardImageRecordStorage,

@@ -65,10 +57,15 @@ from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage

import torch
-import invokeai.backend.util.hotfixes
+import invokeai.backend.util.hotfixes  # noqa: F401 (monkeypatching on import)

if torch.backends.mps.is_available():
-    import invokeai.backend.util.mps_fixes
+    import invokeai.backend.util.mps_fixes  # noqa: F401 (monkeypatching on import)


+config = InvokeAIAppConfig.get_config()
+config.parse_args()
+logger = InvokeAILogger().getLogger(config=config)


class CliCommand(BaseModel):

@@ -488,4 +485,7 @@ def invoke_cli():


if __name__ == "__main__":
-    invoke_cli()
+    if config.version:
+        print(f"InvokeAI version {__version__}")
+    else:
+        invoke_cli()
@@ -230,7 +230,7 @@ def InputField(
    Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
    `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
    `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.


    :param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \
    The UI will always render a suitable component, but sometimes you want something different than the default. \
    For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \
@@ -5,10 +5,10 @@ from typing import Literal
import numpy as np
from pydantic import validator

-from invokeai.app.invocations.primitives import ImageCollectionOutput, ImageField, IntegerCollectionOutput
+from invokeai.app.invocations.primitives import IntegerCollectionOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed

-from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIType, tags, title
+from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title


@title("Integer Range")

@@ -12,7 +12,7 @@ from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion impor
    SDXLConditioningInfo,
)

-from ...backend.model_management import ModelPatcher, ModelType
+from ...backend.model_management.models import ModelType
+from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import ModelNotFoundException
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent

@@ -29,7 +29,7 @@ from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.primitives import ImageField, ImageOutput


-from ...backend.model_management import BaseModelType, ModelType
+from ...backend.model_management import BaseModelType
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import (
    BaseInvocation,

@@ -90,7 +90,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
        return im

    # Find all invalid tiles and replace with a random valid tile
-    replace_count = (tiles_mask == False).sum()
+    replace_count = (tiles_mask is False).sum()
    rng = np.random.default_rng(seed=seed)
    tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]
@@ -15,7 +15,7 @@ from diffusers.models.attention_processor import (
)
from diffusers.schedulers import DPMSolverSDEScheduler
from diffusers.schedulers import SchedulerMixin as Scheduler
-from pydantic import BaseModel, Field, validator
+from pydantic import validator
from torchvision.transforms.functional import resize as tv_resize

from invokeai.app.invocations.metadata import CoreMetadata

@@ -30,7 +30,7 @@ from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models import ModelType, SilenceWarnings

-from ...backend.model_management import BaseModelType, ModelPatcher
+from ...backend.model_management.models import BaseModelType
+from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (

@@ -45,12 +45,10 @@ from ...backend.util.devices import choose_precision, choose_torch_device
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    FieldDescriptions,
    Input,
    InputField,
    InvocationContext,
    OutputField,
    UIType,
    tags,
    title,

@@ -1,5 +1,5 @@
import copy
-from typing import List, Literal, Optional, Union
+from typing import List, Literal, Optional

from pydantic import BaseModel, Field

@@ -16,7 +16,6 @@ from .baseinvocation import (
    InputField,
    InvocationContext,
    OutputField,
-    UIType,
    tags,
    title,
)
@@ -2,14 +2,13 @@

import inspect
import re
-from contextlib import ExitStack
-
+# from contextlib import ExitStack
from typing import List, Literal, Optional, Union

import numpy as np
import torch
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator
from tqdm import tqdm

@@ -72,7 +71,7 @@ class ONNXPromptInvocation(BaseInvocation):
        text_encoder_info = context.services.model_manager.get_model(
            **self.clip.text_encoder.dict(),
        )
-        with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
+        with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder:  # , ExitStack() as stack:
            loras = [
                (context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
                for lora in self.clip.loras

@@ -259,7 +258,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):

        unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())

-        with unet_info as unet, ExitStack() as stack:
+        with unet_info as unet:  # , ExitStack() as stack:
            # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
            loras = [
                (context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
@ -38,13 +38,10 @@ from easing_functions import (
|
||||
SineEaseInOut,
|
||||
SineEaseOut,
|
||||
)
|
||||
from matplotlib.figure import Figure
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.primitives import FloatCollectionOutput
|
||||
|
||||
from ...backend.util.logging import InvokeAILogger
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
|
||||
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal, Optional, Tuple, Union
|
||||
from anyio import Condition
|
||||
from typing import Literal, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import torch
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
||||
from pathlib import Path
|
||||
from typing import Literal, Union
|
||||
from typing import Literal
|
||||
|
||||
import cv2 as cv
|
||||
import numpy as np
|
||||
|
@@ -1,18 +1,14 @@
from abc import ABC, abstractmethod
from logging import Logger
-from typing import List, Union, Optional
+from typing import Optional
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_record_storage import (
    BoardRecord,
    BoardRecordStorageBase,
)

-from invokeai.app.services.image_record_storage import (
-    ImageRecordStorageBase,
-    OffsetPaginatedResults,
-)
+from invokeai.app.services.image_record_storage import ImageRecordStorageBase
from invokeai.app.services.models.board_record import BoardDTO
from invokeai.app.services.models.image_record import ImageDTO, image_record_to_dto
from invokeai.app.services.urls import UrlServiceBase

@@ -1,15 +1,14 @@
from abc import ABC, abstractmethod
from typing import Optional, cast
import sqlite3
import threading
from typing import Optional, Union
import uuid
from abc import ABC, abstractmethod
from typing import Optional, Union, cast

import sqlite3
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.board_record import (
    BoardRecord,
    deserialize_board_record,
)

from pydantic import BaseModel, Field, Extra

@@ -228,7 +227,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
        # Change the name of a board
        if changes.board_name is not None:
            self._cursor.execute(
-                f"""--sql
+                """--sql
                UPDATE boards
                SET board_name = ?
                WHERE board_id = ?;

@@ -239,7 +238,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
        # Change the cover image of a board
        if changes.cover_image_name is not None:
            self._cursor.execute(
-                f"""--sql
+                """--sql
                UPDATE boards
                SET cover_image_name = ?
                WHERE board_id = ?;
@@ -167,7 +167,7 @@ from argparse import ArgumentParser
from omegaconf import OmegaConf, DictConfig, ListConfig
from pathlib import Path
from pydantic import BaseSettings, Field, parse_obj_as
-from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args
+from typing import ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args

INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")

@@ -394,7 +394,7 @@ class InvokeAIAppConfig(InvokeAISettings):
    free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
    max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
    max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
-    precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='auto',description='Floating point precision', category='Memory/Performance')
+    precision : Literal['auto', 'float16', 'float32', 'autocast'] = Field(default='auto', description='Floating point precision', category='Memory/Performance')
    sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
    xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
    tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
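The `precision` change above (and `log_format`/`log_level` below) replaces `Literal[tuple([...])]` with literals spelled directly in the annotation. Pydantic evaluated the old form at runtime, but static checkers cannot, since `Literal` arguments must be literal expressions. A minimal sketch of the corrected pattern, assuming a pydantic `BaseModel`:

from typing import Literal
from pydantic import BaseModel, Field

class Settings(BaseModel):
    # Spelled-out literals: validated by pydantic AND visible to mypy/pyright.
    precision: Literal["auto", "float16", "float32", "autocast"] = Field(default="auto")

print(Settings(precision="float16").precision)  # float16
try:
    Settings(precision="bfloat16")  # not one of the allowed literals
except Exception as err:
    print(type(err).__name__)  # ValidationError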
@@ -415,8 +415,8 @@ class InvokeAIAppConfig(InvokeAISettings):

    log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
    # note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
-    log_format : Literal[tuple(['plain','color','syslog','legacy'])] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
-    log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
+    log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
+    log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")

    version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
    # fmt: on

@@ -438,7 +438,7 @@ class InvokeAIAppConfig(InvokeAISettings):
        if conf is None:
            try:
                conf = OmegaConf.load(self.root_dir / INIT_FILE)
-            except:
+            except Exception:
                pass
        InvokeAISettings.initconf = conf

@@ -457,7 +457,7 @@ class InvokeAIAppConfig(InvokeAISettings):
        """
        if (
            cls.singleton_config is None
-            or type(cls.singleton_config) != cls
+            or type(cls.singleton_config) is not cls
            or (kwargs and cls.singleton_init != kwargs)
        ):
            cls.singleton_config = cls(**kwargs)
@@ -9,7 +9,8 @@ import networkx as nx
from pydantic import BaseModel, root_validator, validator
from pydantic.fields import Field

-from ..invocations import *
+# Importing * is bad karma but needed here for node detection
+from ..invocations import *  # noqa: F401 F403
from ..invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,

@@ -445,7 +446,7 @@ class Graph(BaseModel):
        node = graph.nodes[node_id]

        # Ensure the node type matches the new node
-        if type(node) != type(new_node):
+        if type(node) is not type(new_node):
            raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}")

        # Ensure the new id is either the same or is not in the graph

@@ -632,7 +633,7 @@ class Graph(BaseModel):
            [
                t
                for input_field in input_fields
-                for t in ([input_field] if get_origin(input_field) == None else get_args(input_field))
+                for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
                if t != NoneType
            ]
        )  # Get unique types

@@ -923,7 +924,7 @@ class GraphExecutionState(BaseModel):
            None,
        )

-        if next_node_id == None:
+        if next_node_id is None:
            return None

        # Get all parents of the next node
@@ -179,7 +179,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
        return None if image_name not in self.__cache else self.__cache[image_name]

    def __set_cache(self, image_name: Path, image: PILImageType):
-        if not image_name in self.__cache:
+        if image_name not in self.__cache:
            self.__cache[image_name] = image
            self.__cache_ids.put(image_name)  # TODO: this should refresh position for LRU cache
            if len(self.__cache) > self.__max_cache_size:

@@ -280,7 +280,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
        self._lock.acquire()

        self._cursor.execute(
-            f"""--sql
+            """--sql
            SELECT images.metadata FROM images
            WHERE image_name = ?;
            """,

@@ -307,7 +307,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
        # Change the category of the image
        if changes.image_category is not None:
            self._cursor.execute(
-                f"""--sql
+                """--sql
                UPDATE images
                SET image_category = ?
                WHERE image_name = ?;

@@ -318,7 +318,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
        # Change the session associated with the image
        if changes.session_id is not None:
            self._cursor.execute(
-                f"""--sql
+                """--sql
                UPDATE images
                SET session_id = ?
                WHERE image_name = ?;

@@ -329,7 +329,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
        # Change the image's `is_intermediate`` flag
        if changes.is_intermediate is not None:
            self._cursor.execute(
-                f"""--sql
+                """--sql
                UPDATE images
                SET is_intermediate = ?
                WHERE image_name = ?;

@@ -340,7 +340,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
        # Change the image's `starred`` state
        if changes.starred is not None:
            self._cursor.execute(
-                f"""--sql
+                """--sql
                UPDATE images
                SET starred = ?
                WHERE image_name = ?;
@@ -1,4 +1,3 @@
-import json
from abc import ABC, abstractmethod
from logging import Logger
from typing import TYPE_CHECKING, Optional

@@ -379,10 +378,10 @@ class ImageService(ImageServiceABC):
            self._services.image_files.delete(image_name)
            self._services.image_records.delete(image_name)
        except ImageRecordDeleteException:
-            self._services.logger.error(f"Failed to delete image record")
+            self._services.logger.error("Failed to delete image record")
            raise
        except ImageFileDeleteException:
-            self._services.logger.error(f"Failed to delete image file")
+            self._services.logger.error("Failed to delete image file")
            raise
        except Exception as e:
            self._services.logger.error("Problem deleting image record and file")

@@ -395,10 +394,10 @@ class ImageService(ImageServiceABC):
                self._services.image_files.delete(image_name)
            self._services.image_records.delete_many(image_names)
        except ImageRecordDeleteException:
-            self._services.logger.error(f"Failed to delete image records")
+            self._services.logger.error("Failed to delete image records")
            raise
        except ImageFileDeleteException:
-            self._services.logger.error(f"Failed to delete image files")
+            self._services.logger.error("Failed to delete image files")
            raise
        except Exception as e:
            self._services.logger.error("Problem deleting image records and files")

@@ -412,10 +411,10 @@ class ImageService(ImageServiceABC):
                self._services.image_files.delete(image_name)
            return count
        except ImageRecordDeleteException:
-            self._services.logger.error(f"Failed to delete image records")
+            self._services.logger.error("Failed to delete image records")
            raise
        except ImageFileDeleteException:
-            self._services.logger.error(f"Failed to delete image files")
+            self._services.logger.error("Failed to delete image files")
            raise
        except Exception as e:
            self._services.logger.error("Problem deleting image records and files")
@@ -8,6 +8,7 @@ if TYPE_CHECKING:
    from invokeai.app.services.board_images import BoardImagesServiceABC
    from invokeai.app.services.boards import BoardServiceABC
    from invokeai.app.services.images import ImageServiceABC
+    from invokeai.app.services.invocation_stats import InvocationStatsServiceBase
    from invokeai.app.services.model_manager_service import ModelManagerServiceBase
    from invokeai.app.services.events import EventServiceBase
    from invokeai.app.services.latent_storage import LatentsStorageBase

@@ -1,7 +1,6 @@
# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
-"""Utility to collect execution time and GPU usage stats on invocations in flight"""
+"""Utility to collect execution time and GPU usage stats on invocations in flight

-"""
Usage:

statistics = InvocationStatsService(graph_execution_manager)

@@ -29,6 +28,7 @@ The abstract base class for this class is InvocationStatsServiceBase. An impleme
writes to the system log is stored in InvocationServices.performance_statistics.
"""

+import psutil
import time
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager

@@ -42,6 +42,11 @@ import invokeai.backend.util.logging as logger
from ..invocations.baseinvocation import BaseInvocation
from .graph import GraphExecutionState
from .item_storage import ItemStorageABC
+from .model_manager_service import ModelManagerService
+from invokeai.backend.model_management.model_cache import CacheStats

+# size of GIG in bytes
+GIG = 1073741824


class InvocationStatsServiceBase(ABC):

@@ -89,6 +94,8 @@ class InvocationStatsServiceBase(ABC):
        invocation_type: str,
        time_used: float,
        vram_used: float,
+        ram_used: float,
+        ram_changed: float,
    ):
        """
        Add timing information on execution of a node. Usually

@@ -97,6 +104,8 @@ class InvocationStatsServiceBase(ABC):
        :param invocation_type: String literal type of the node
        :param time_used: Time used by node's exection (sec)
        :param vram_used: Maximum VRAM used during exection (GB)
+        :param ram_used: Current RAM available (GB)
+        :param ram_changed: Change in RAM usage over course of the run (GB)
        """
        pass

@@ -115,6 +124,9 @@ class NodeStats:
    calls: int = 0
    time_used: float = 0.0  # seconds
    max_vram: float = 0.0  # GB
+    cache_hits: int = 0
+    cache_misses: int = 0
+    cache_high_watermark: int = 0


@dataclass
@@ -133,31 +145,62 @@ class InvocationStatsService(InvocationStatsServiceBase):
        self.graph_execution_manager = graph_execution_manager
        # {graph_id => NodeLog}
        self._stats: Dict[str, NodeLog] = {}
+        self._cache_stats: Dict[str, CacheStats] = {}
+        self.ram_used: float = 0.0
+        self.ram_changed: float = 0.0

    class StatsContext:
-        def __init__(self, invocation: BaseInvocation, graph_id: str, collector: "InvocationStatsServiceBase"):
+        """Context manager for collecting statistics."""
+
+        invocation: BaseInvocation = None
+        collector: "InvocationStatsServiceBase" = None
+        graph_id: str = None
+        start_time: int = 0
+        ram_used: int = 0
+        model_manager: ModelManagerService = None
+
+        def __init__(
+            self,
+            invocation: BaseInvocation,
+            graph_id: str,
+            model_manager: ModelManagerService,
+            collector: "InvocationStatsServiceBase",
+        ):
+            """Initialize statistics for this run."""
            self.invocation = invocation
            self.collector = collector
            self.graph_id = graph_id
            self.start_time = 0
            self.ram_used = 0
+            self.model_manager = model_manager

        def __enter__(self):
            self.start_time = time.time()
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats()
+            self.ram_used = psutil.Process().memory_info().rss
+            if self.model_manager:
+                self.model_manager.collect_cache_stats(self.collector._cache_stats[self.graph_id])

        def __exit__(self, *args):
+            """Called on exit from the context."""
+            ram_used = psutil.Process().memory_info().rss
+            self.collector.update_mem_stats(
+                ram_used=ram_used / GIG,
+                ram_changed=(ram_used - self.ram_used) / GIG,
+            )
            self.collector.update_invocation_stats(
-                self.graph_id,
-                self.invocation.type,
-                time.time() - self.start_time,
-                torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0,
+                graph_id=self.graph_id,
+                invocation_type=self.invocation.type,
+                time_used=time.time() - self.start_time,
+                vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
            )

    def collect_stats(
        self,
        invocation: BaseInvocation,
        graph_execution_state_id: str,
+        model_manager: ModelManagerService,
    ) -> StatsContext:
        """
        Return a context object that will capture the statistics.

@@ -166,7 +209,8 @@ class InvocationStatsService(InvocationStatsServiceBase):
        """
        if not self._stats.get(graph_execution_state_id):  # first time we're seeing this
            self._stats[graph_execution_state_id] = NodeLog()
-        return self.StatsContext(invocation, graph_execution_state_id, self)
+            self._cache_stats[graph_execution_state_id] = CacheStats()
+        return self.StatsContext(invocation, graph_execution_state_id, model_manager, self)

    def reset_all_stats(self):
        """Zero all statistics"""
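The reworked `StatsContext` above is a classic snapshot-on-enter, report-on-exit context manager: `__enter__` records wall-clock time and process RSS, `__exit__` computes the deltas and pushes them to the collector. A simplified, self-contained sketch of the same shape (the `report` callback and names here are illustrative, not InvokeAI's API):

import time
from typing import Callable

class TimingContext:
    """Snapshot on __enter__, compute deltas and hand them off on __exit__,
    mirroring the shape of StatsContext above."""

    def __init__(self, name: str, report: Callable[[str, float], None]):
        self.name = name
        self.report = report  # stands in for collector.update_invocation_stats
        self.start_time = 0.0

    def __enter__(self) -> "TimingContext":
        self.start_time = time.time()  # snapshot; RSS/VRAM would go here too
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.report(self.name, time.time() - self.start_time)

def print_report(name: str, seconds: float) -> None:
    print(f"{name}: {seconds:.3f}s")

with TimingContext("demo_node", print_report):
    time.sleep(0.05)  # stands in for invoking a node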
@@ -179,13 +223,36 @@ class InvocationStatsService(InvocationStatsServiceBase):
        except KeyError:
            logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")

-    def update_invocation_stats(self, graph_id: str, invocation_type: str, time_used: float, vram_used: float):
+    def update_mem_stats(
+        self,
+        ram_used: float,
+        ram_changed: float,
+    ):
+        """
+        Update the collector with RAM memory usage info.
+
+        :param ram_used: How much RAM is currently in use.
+        :param ram_changed: How much RAM changed since last generation.
+        """
+        self.ram_used = ram_used
+        self.ram_changed = ram_changed
+
+    def update_invocation_stats(
+        self,
+        graph_id: str,
+        invocation_type: str,
+        time_used: float,
+        vram_used: float,
+    ):
        """
        Add timing information on execution of a node. Usually
        used internally.
        :param graph_id: ID of the graph that is currently executing
        :param invocation_type: String literal type of the node
-        :param time_used: Floating point seconds used by node's exection
+        :param time_used: Time used by node's exection (sec)
+        :param vram_used: Maximum VRAM used during exection (GB)
+        :param ram_used: Current RAM available (GB)
+        :param ram_changed: Change in RAM usage over course of the run (GB)
        """
        if not self._stats[graph_id].nodes.get(invocation_type):
            self._stats[graph_id].nodes[invocation_type] = NodeStats()

@@ -197,7 +264,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
    def log_stats(self):
        """
        Send the statistics to the system logger at the info level.
-        Stats will only be printed if when the execution of the graph
+        Stats will only be printed when the execution of the graph
        is complete.
        """
        completed = set()

@@ -208,16 +275,30 @@ class InvocationStatsService(InvocationStatsServiceBase):

            total_time = 0
            logger.info(f"Graph stats: {graph_id}")
-            logger.info("Node Calls Seconds VRAM Used")
+            logger.info(f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}")
            for node_type, stats in self._stats[graph_id].nodes.items():
-                logger.info(f"{node_type:<20} {stats.calls:>5} {stats.time_used:7.3f}s {stats.max_vram:4.2f}G")
+                logger.info(f"{node_type:>30} {stats.calls:>4} {stats.time_used:7.3f}s {stats.max_vram:4.3f}G")
                total_time += stats.time_used

+            cache_stats = self._cache_stats[graph_id]
+            hwm = cache_stats.high_watermark / GIG
+            tot = cache_stats.cache_size / GIG
+            loaded = sum([v for v in cache_stats.loaded_model_sizes.values()]) / GIG
+
            logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
+            logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)")
+            logger.info(f"RAM used to load models: {loaded:4.2f}G")
            if torch.cuda.is_available():
-                logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9))
+                logger.info("VRAM in use: " + "%4.3fG" % (torch.cuda.memory_allocated() / GIG))
+            logger.info("RAM cache statistics:")
+            logger.info(f"   Model cache hits: {cache_stats.hits}")
+            logger.info(f"   Model cache misses: {cache_stats.misses}")
+            logger.info(f"   Models cached: {cache_stats.in_cache}")
+            logger.info(f"   Models cleared from cache: {cache_stats.cleared}")
+            logger.info(f"   Cache high water mark: {hwm:4.2f}/{tot:4.2f}G")

            completed.add(graph_id)

        for graph_id in completed:
            del self._stats[graph_id]
+            del self._cache_stats[graph_id]
@@ -60,7 +60,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
        return None if name not in self.__cache else self.__cache[name]

    def __set_cache(self, name: str, data: torch.Tensor):
-        if not name in self.__cache:
+        if name not in self.__cache:
            self.__cache[name] = data
            self.__cache_ids.put(name)
            if self.__cache_ids.qsize() > self.__max_cache_size:

@@ -22,6 +22,7 @@ from invokeai.backend.model_management import (
    ModelNotFoundException,
)
from invokeai.backend.model_management.model_search import FindModels
+from invokeai.backend.model_management.model_cache import CacheStats

import torch
from invokeai.app.models.exceptions import CanceledException

@@ -276,6 +277,13 @@ class ModelManagerServiceBase(ABC):
        """
        pass

+    @abstractmethod
+    def collect_cache_stats(self, cache_stats: CacheStats):
+        """
+        Reset model cache statistics for graph with graph_id.
+        """
+        pass
+
    @abstractmethod
    def commit(self, conf_file: Optional[Path] = None) -> None:
        """

@@ -500,6 +508,12 @@ class ModelManagerService(ModelManagerServiceBase):
        self.logger.debug(f"convert model {model_name}")
        return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)

+    def collect_cache_stats(self, cache_stats: CacheStats):
+        """
+        Reset model cache statistics for graph with graph_id.
+        """
+        self.mgr.cache.stats = cache_stats
+
    def commit(self, conf_file: Optional[Path] = None):
        """
        Write current configuration out to the indicated file.

@@ -86,7 +86,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):

                # Invoke
                try:
-                    with statistics.collect_stats(invocation, graph_execution_state.id):
+                    graph_id = graph_execution_state.id
+                    model_manager = self.__invoker.services.model_manager
+                    with statistics.collect_stats(invocation, graph_id, model_manager):
                        # use the internal invoke_internal(), which wraps the node's invoke() method in
                        # this accomodates nodes which require a value, but get it only from a
                        # connection
@@ -1,3 +1,4 @@
+from typing import Union
import torch
import numpy as np
import cv2

@@ -5,7 +6,7 @@ from PIL import Image
from diffusers.utils import PIL_INTERPOLATION

from einops import rearrange
-from controlnet_aux.util import HWC3, resize_image
+from controlnet_aux.util import HWC3

###################################################################
# Copy of scripts/lvminthin.py from Mikubill/sd-webui-controlnet

@@ -232,7 +233,8 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:
    k0 = float(h) / old_h
    k1 = float(w) / old_w

-    safeint = lambda x: int(np.round(x))
+    def safeint(x: Union[int, float]) -> int:
+        return int(np.round(x))

    # if resize_mode == external_code.ResizeMode.OUTER_FIT:
    if resize_mode == "fill_resize":  # OUTER_FIT

@@ -5,7 +5,6 @@ from invokeai.app.models.image import ProgressImage
from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
from ...backend.stable_diffusion import PipelineIntermediateState
-from invokeai.app.services.config import InvokeAIAppConfig
from ...backend.model_management.models import BaseModelType

@@ -1,5 +1,5 @@
"""
Initialization file for invokeai.backend
"""
-from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubModelType, ModelInfo
-from .model_management.models import SilenceWarnings
+from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubModelType, ModelInfo  # noqa: F401
+from .model_management.models import SilenceWarnings  # noqa: F401

@@ -1,14 +1,16 @@
"""
Initialization file for invokeai.backend.image_util methods.
"""
-from .patchmatch import PatchMatch
-from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata
-from .seamless import configure_model_padding
-from .txt2mask import Txt2Mask
-from .util import InitImageResizer, make_grid
+from .patchmatch import PatchMatch  # noqa: F401
+from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata  # noqa: F401
+from .seamless import configure_model_padding  # noqa: F401
+from .txt2mask import Txt2Mask  # noqa: F401
+from .util import InitImageResizer, make_grid  # noqa: F401


def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False):
    from PIL import ImageDraw

    if not debug_status:
        return
@@ -26,7 +26,7 @@ class PngWriter:
        dirlist = sorted(os.listdir(self.outdir), reverse=True)
        # find the first filename that matches our pattern or return 000000.0.png
        existing_name = next(
-            (f for f in dirlist if re.match("^(\d+)\..*\.png", f)),
+            (f for f in dirlist if re.match(r"^(\d+)\..*\.png", f)),
            "0000000.0.png",
        )
        basecount = int(existing_name.split(".", 1)[0]) + 1

@@ -98,11 +98,11 @@ class PromptFormatter:
        # to do: put model name into the t2i object
        # switches.append(f'--model{t2i.model_name}')
        if opt.seamless or t2i.seamless:
-            switches.append(f"--seamless")
+            switches.append("--seamless")
        if opt.init_img:
            switches.append(f"-I{opt.init_img}")
        if opt.fit:
-            switches.append(f"--fit")
+            switches.append("--fit")
        if opt.strength and opt.init_img is not None:
            switches.append(f"-f{opt.strength or t2i.strength}")
        if opt.gfpgan_strength:

@@ -52,7 +52,6 @@ from invokeai.frontend.install.widgets import (
    SingleSelectColumns,
    CenteredButtonPress,
    FileBox,
-    IntTitleSlider,
    set_min_terminal_size,
    CyclingForm,
    MIN_COLS,

@@ -308,7 +307,7 @@ class editOptsForm(CyclingForm, npyscreen.FormMultiPage):
        first_time = not (config.root_path / "invokeai.yaml").exists()
        access_token = HfFolder.get_token()
        window_width, window_height = get_terminal_size()
-        label = """Configure startup settings. You can come back and change these later.
+        label = """Configure startup settings. You can come back and change these later.
Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.
Use cursor arrows to make a checkbox selection, and space to toggle.
"""
@@ -116,7 +116,7 @@ class MigrateTo3(object):
        appropriate location within the destination models directory.
        """
        directories_scanned = set()
-        for root, dirs, files in os.walk(src_dir):
+        for root, dirs, files in os.walk(src_dir, followlinks=True):
            for d in dirs:
                try:
                    model = Path(root, d)

@@ -525,7 +525,7 @@ def do_migrate(src_directory: Path, dest_directory: Path):
    if version_3:  # write into the dest directory
        try:
            shutil.copy(dest_directory / "configs" / "models.yaml", config_file)
-        except:
+        except Exception:
            MigrateTo3.initialize_yaml(config_file)
        mgr = ModelManager(config_file)  # important to initialize BEFORE moving the models directory
        (dest_directory / "models").replace(dest_models)

@@ -553,7 +553,7 @@ def main():
    parser = argparse.ArgumentParser(
        prog="invokeai-migrate3",
        description="""
-This will copy and convert the models directory and the configs/models.yaml from the InvokeAI 2.3 format
+This will copy and convert the models directory and the configs/models.yaml from the InvokeAI 2.3 format
'--from-directory' root to the InvokeAI 3.0 '--to-directory' root. These may be abbreviated '--from' and '--to'.a

The old models directory and config file will be renamed 'models.orig' and 'models.yaml.orig' respectively.

@@ -12,7 +12,6 @@ from typing import Optional, List, Dict, Callable, Union, Set
import requests
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
-import onnx
import torch
from huggingface_hub import hf_hub_url, HfFolder, HfApi
from omegaconf import OmegaConf

@@ -1,10 +1,10 @@
"""
Initialization file for invokeai.backend.model_management
"""
-from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
-from .model_cache import ModelCache
-from .lora import ModelPatcher, ONNXModelPatcher
-from .models import (
+from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType  # noqa: F401
+from .model_cache import ModelCache  # noqa: F401
+from .lora import ModelPatcher, ONNXModelPatcher  # noqa: F401
+from .models import (  # noqa: F401
    BaseModelType,
    ModelType,
    SubModelType,

@@ -12,5 +12,4 @@ from .models import (
    ModelNotFoundException,
    DuplicateModelException,
)
-from .model_merge import ModelMerger, MergeInterpolationMethod
-from .lora import ModelPatcher
+from .model_merge import ModelMerger, MergeInterpolationMethod  # noqa: F401
@@ -5,21 +5,16 @@ from contextlib import contextmanager
from typing import Optional, Dict, Tuple, Any, Union, List
from pathlib import Path

-import torch
-from safetensors.torch import load_file
-from torch.utils.hooks import RemovableHandle
-
-from diffusers.models import UNet2DConditionModel
-from transformers import CLIPTextModel
from onnx import numpy_helper
from onnxruntime import OrtValue
import numpy as np

+import torch
+from compel.embeddings_provider import BaseTextualInversionManager
+from diffusers.models import UNet2DConditionModel
+from safetensors.torch import load_file
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from .models.lora import LoRAModel

"""
loras = [
    (lora_model1, 0.7),

@@ -52,7 +47,7 @@ class ModelPatcher:
                module = module.get_submodule(submodule_name)
                module_key += "." + submodule_name
                submodule_name = key_parts.pop(0)
-            except:
+            except Exception:
                submodule_name += "_" + key_parts.pop(0)

        module = module.get_submodule(submodule_name)

@@ -312,7 +307,8 @@ class TextualInversionManager(BaseTextualInversionManager):


class ONNXModelPatcher:
-    from .models.base import IAIOnnxRuntimeModel, OnnxRuntimeModel
+    from .models.base import IAIOnnxRuntimeModel
+    from diffusers import OnnxRuntimeModel

    @classmethod
    @contextmanager

@@ -341,7 +337,7 @@ class ONNXModelPatcher:
    def apply_lora(
        cls,
        model: IAIOnnxRuntimeModel,
-        loras: List[Tuple[LoraModel, float]],
+        loras: List[Tuple[LoRAModel, float]],
        prefix: str,
    ):
        from .models.base import IAIOnnxRuntimeModel
@ -21,12 +21,12 @@ import os
|
||||
import sys
|
||||
import hashlib
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union, types, Optional, Type, Any
|
||||
|
||||
import torch
|
||||
|
||||
import logging
|
||||
import invokeai.backend.util.logging as logger
|
||||
from .models import BaseModelType, ModelType, SubModelType, ModelBase
|
||||
|
||||
@ -41,6 +41,18 @@ DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
|
||||
GIG = 1073741824
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats(object):
|
||||
hits: int = 0 # cache hits
|
||||
misses: int = 0 # cache misses
|
||||
high_watermark: int = 0 # amount of cache used
|
||||
in_cache: int = 0 # number of models in cache
|
||||
cleared: int = 0 # number of models cleared to make space
|
||||
cache_size: int = 0 # total size of cache
|
||||
# {submodel_key => size}
|
||||
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ModelLocker(object):
|
||||
"Forward declaration"
|
||||
pass
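
A rough sketch of how a stats object like the `CacheStats` dataclass above gets threaded through a cache; the `SimpleCache` class here is illustrative, not the actual `ModelCache` API:

from dataclasses import dataclass, field
from typing import Dict

@dataclass
class CacheStats:
    hits: int = 0
    misses: int = 0
    loaded_model_sizes: Dict[str, int] = field(default_factory=dict)

class SimpleCache:
    def __init__(self):
        self.stats = None  # a caller attaches a CacheStats to opt in to collection
        self._store = {}

    def get(self, key, loader):
        if key in self._store:
            if self.stats:
                self.stats.hits += 1
            return self._store[key]
        if self.stats:
            self.stats.misses += 1
        self._store[key] = loader()
        return self._store[key]

cache = SimpleCache()
cache.stats = CacheStats()
cache.get("a", lambda: object())
cache.get("a", lambda: object())
print(cache.stats.hits, cache.stats.misses)  # 1 1

Keeping `stats` optional (`None` by default) means collection costs nothing unless a caller asks for it, which matches the guarded `if self.stats:` checks in the hunks below.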

@ -115,6 +127,9 @@ class ModelCache(object):
self.sha_chunksize = sha_chunksize
self.logger = logger

# used for stats collection
self.stats = None

self._cached_models = dict()
self._cache_stack = list()

@ -181,13 +196,14 @@ class ModelCache(object):
model_type=model_type,
submodel_type=submodel,
)

# TODO: lock for no copies on simultaneous calls?
cache_entry = self._cached_models.get(key, None)
if cache_entry is None:
self.logger.info(
f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
)
if self.stats:
self.stats.misses += 1

# this will remove older cached models until
# there is sufficient room to load the requested model
@ -201,6 +217,17 @@ class ModelCache(object):

cache_entry = _CacheRecord(self, model, mem_used)
self._cached_models[key] = cache_entry
else:
if self.stats:
self.stats.hits += 1

if self.stats:
self.stats.cache_size = self.max_cache_size * GIG
self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[key] = max(
self.stats.loaded_model_sizes.get(key, 0), model_info.get_size(submodel)
)

with suppress(Exception):
self._cache_stack.remove(key)
@ -246,7 +273,7 @@ class ModelCache(object):
self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
self.cache._print_cuda_stats()

except:
except Exception:
self.cache_entry.unlock()
raise

@ -280,14 +307,14 @@ class ModelCache(object):
"""
Given the HF repo id or path to a model on disk, returns a unique
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs

:param model_path: Path to model file/directory on disk.
"""
return self._local_model_hash(model_path)

def cache_size(self) -> float:
"Return the current size of the cache, in GB"
current_cache_size = sum([m.size for m in self._cached_models.values()])
return current_cache_size / GIG
"""Return the current size of the cache, in GB."""
return self._cache_size() / GIG

def _has_cuda(self) -> bool:
return self.execution_device.type == "cuda"
@ -310,12 +337,15 @@ class ModelCache(object):
f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}"
)

def _cache_size(self) -> int:
return sum([m.size for m in self._cached_models.values()])

def _make_cache_room(self, model_size):
# calculate how much memory this model will require
# multiplier = 2 if self.precision==torch.float32 else 1
bytes_needed = model_size
maximum_size = self.max_cache_size * GIG  # stored in GB, convert to bytes
current_size = sum([m.size for m in self._cached_models.values()])
current_size = self._cache_size()

if current_size + bytes_needed > maximum_size:
self.logger.debug(
@ -364,6 +394,8 @@ class ModelCache(object):
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
)
current_size -= cache_entry.size
if self.stats:
self.stats.cleared += 1
del self._cache_stack[pos]
del self._cached_models[model_key]
del cache_entry
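
A compact sketch of the make-room strategy visible in `_make_cache_room` and the unload loop above: compare the requested size against a byte budget (`max_cache_size * GIG`) and evict from the old end of an LRU stack until the new model fits. Names and signatures here are illustrative, not the real class:

GIG = 1073741824  # bytes per GB, as in the hunk above

def make_room(cache_stack, cached_sizes, bytes_needed, max_cache_size_gb):
    # Evict least-recently-used entries until bytes_needed fits in the budget.
    maximum_size = max_cache_size_gb * GIG
    current_size = sum(cached_sizes.values())
    while cache_stack and current_size + bytes_needed > maximum_size:
        key = cache_stack.pop(0)  # oldest entry first
        current_size -= cached_sizes.pop(key)
    return current_size

stack = ["m1", "m2"]
sizes = {"m1": 2 * GIG, "m2": 1 * GIG}
print(make_room(stack, sizes, 3 * GIG, 4) / GIG)  # 1.0 -- m1 was evicted to make room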

@ -419,12 +419,12 @@ class ModelManager(object):
base_model_str, model_type_str, model_name = model_key.split("/", 2)
try:
model_type = ModelType(model_type_str)
except:
except Exception:
raise Exception(f"Unknown model type: {model_type_str}")

try:
base_model = BaseModelType(base_model_str)
except:
except Exception:
raise Exception(f"Unknown base model: {base_model_str}")

return (model_name, base_model, model_type)
@ -855,7 +855,7 @@ class ModelManager(object):
info.pop("config")

result = self.add_model(model_name, base_model, model_type, model_attributes=info, clobber=True)
except:
except Exception:
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
rmtree(new_diffusers_path)
raise
@ -1042,7 +1042,7 @@ class ModelManager(object):
# Patch in the SD VAE from core so that it is available for use by the UI
try:
self.heuristic_import({str(self.resolve_model_path("core/convert/sd-vae-ft-mse"))})
except:
except Exception:
pass

installer = ModelInstall(

@ -217,9 +217,9 @@ class ModelProbe(object):
raise "The model {model_name} is potentially infected by malware. Aborting import."


###################################################3
# ##################################################3
# Checkpoint probing
###################################################3
# ##################################################3
class ProbeBase(object):
def get_base_type(self) -> BaseModelType:
pass
@ -431,7 +431,7 @@ class PipelineFolderProbe(FolderProbeBase):
return ModelVariantType.Depth
elif in_channels == 4:
return ModelVariantType.Normal
except:
except Exception:
pass
return ModelVariantType.Normal

@ -56,7 +56,7 @@ class ModelSearch(ABC):
self.on_search_completed()

def walk_directory(self, path: Path):
for root, dirs, files in os.walk(path):
for root, dirs, files in os.walk(path, followlinks=True):
if str(Path(root).name).startswith("."):
self._pruned_paths.add(root)
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
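
The `followlinks=True` additions change a non-obvious default: `os.walk` does not descend into directories that are symbolic links unless asked to. A small self-contained demonstration on a POSIX system (paths are created in a fresh temp directory):

import os
import tempfile

root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "real", "sub"))
os.symlink(os.path.join(root, "real"), os.path.join(root, "link"))

walk_default = [r for r, _, _ in os.walk(root)]
walk_links = [r for r, _, _ in os.walk(root, followlinks=True)]
print(len(walk_default) < len(walk_links))  # True: the linked tree is only traversed with followlinks

The trade-off is that a cycle of symlinks can cause repeated traversal, which is presumably acceptable for a model-directory scan.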

@ -2,7 +2,7 @@ import inspect
from enum import Enum
from pydantic import BaseModel
from typing import Literal, get_origin
from .base import (
from .base import (  # noqa: F401
BaseModelType,
ModelType,
SubModelType,
@ -118,7 +118,7 @@ def get_model_config_enums():
fields = model_config.__annotations__
try:
field = fields["model_format"]
except:
except Exception:
raise Exception("format field not found")

# model_format: None

@ -3,27 +3,28 @@ import os
import sys
import typing
import inspect
from enum import Enum
import warnings
from abc import ABCMeta, abstractmethod
from contextlib import suppress
from enum import Enum
from pathlib import Path
from picklescan.scanner import scan_file_path

import torch
import numpy as np
import safetensors.torch
from pathlib import Path
from diffusers import DiffusionPipeline, ConfigMixin, OnnxRuntimeModel

from contextlib import suppress
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union

import onnx
import safetensors.torch
from diffusers import DiffusionPipeline, ConfigMixin
from onnx import numpy_helper
from onnxruntime import (
InferenceSession,
SessionOptions,
get_available_providers,
)
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging


class DuplicateModelException(Exception):
@ -171,7 +172,7 @@ class ModelBase(metaclass=ABCMeta):
fields = value.__annotations__
try:
field = fields["model_format"]
except:
except Exception:
raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})")

if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
@ -244,7 +245,7 @@ class DiffusersModel(ModelBase):
try:
config_data = DiffusionPipeline.load_config(self.model_path)
# config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
except:
except Exception:
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")

config_data.pop("_ignore_files", None)
@ -343,7 +344,7 @@ def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, vari
with open(os.path.join(model_path, file), "r") as f:
index_data = json.loads(f.read())
return int(index_data["metadata"]["total_size"])
except:
except Exception:
pass

# calculate files size if there is no index file
@ -440,7 +441,7 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
if str(path).endswith(".safetensors"):
try:
checkpoint = _fast_safetensors_reader(path)
except:
except Exception:
# TODO: create issue for support "meta"?
checkpoint = safetensors.torch.load_file(path, device="cpu")
else:
@ -452,11 +453,6 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
return checkpoint


import warnings
from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging


class SilenceWarnings(object):
def __init__(self):
self.transformers_verbosity = transformers_logging.get_verbosity()
@ -639,7 +635,7 @@ class IAIOnnxRuntimeModel:
raise Exception("You should call create_session before running model")

inputs = {k: np.array(v) for k, v in kwargs.items()}
output_names = self.session.get_outputs()
# output_names = self.session.get_outputs()
# for k in inputs:
# self.io_binding.bind_cpu_input(k, inputs[k])
# for name in output_names:
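
The `SilenceWarnings` class whose first lines appear above saves library verbosity on entry and restores it on exit. A generic sketch of that save/lower/restore shape, using only the standard `warnings` module (the real class also adjusts diffusers and transformers logging levels):

import warnings

class SilenceDeprecations:
    # Context manager: suppress DeprecationWarning inside the block, restore after.
    def __enter__(self):
        self._ctx = warnings.catch_warnings()
        self._ctx.__enter__()
        warnings.simplefilter("ignore", DeprecationWarning)
        return self

    def __exit__(self, exc_type, exc, tb):
        self._ctx.__exit__(exc_type, exc, tb)
        return False  # do not swallow exceptions

with SilenceDeprecations():
    warnings.warn("old api", DeprecationWarning)  # silently dropped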

@ -43,7 +43,7 @@ class ControlNetModel(ModelBase):
try:
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
# config = json.loads(os.path.join(self.model_path, "config.json"))
except:
except Exception:
raise Exception("Invalid controlnet model! (config.json not found or invalid)")

model_class_name = config.get("_class_name", None)
@ -53,7 +53,7 @@ class ControlNetModel(ModelBase):
try:
self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
self.model_size = calc_model_size_by_fs(self.model_path)
except:
except Exception:
raise Exception("Invalid ControlNet model!")

def get_size(self, child_type: Optional[SubModelType] = None):
@ -78,7 +78,7 @@ class ControlNetModel(ModelBase):
variant=variant,
)
break
except:
except Exception:
pass
if not model:
raise ModelNotFoundException()

@ -330,5 +330,5 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
config_path = config_path.relative_to(app_config.root_path)
return str(config_path)

except:
except Exception:
return None

@ -1,25 +1,17 @@
import os
import json
from enum import Enum
from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
from typing import Literal

from diffusers import OnnxRuntimeModel
from .base import (
ModelBase,
ModelConfigBase,
BaseModelType,
ModelType,
SubModelType,
ModelVariantType,
DiffusersModel,
SchedulerPredictionType,
SilenceWarnings,
read_checkpoint_meta,
classproperty,
OnnxRuntimeModel,
IAIOnnxRuntimeModel,
)
from invokeai.app.services.config import InvokeAIAppConfig


class StableDiffusionOnnxModelFormat(str, Enum):

@ -44,14 +44,14 @@ class VaeModel(ModelBase):
try:
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
# config = json.loads(os.path.join(self.model_path, "config.json"))
except:
except Exception:
raise Exception("Invalid vae model! (config.json not found or invalid)")

try:
vae_class_name = config.get("_class_name", "AutoencoderKL")
self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name])
self.model_size = calc_model_size_by_fs(self.model_path)
except:
except Exception:
raise Exception("Invalid vae model! (Unkown vae type)")

def get_size(self, child_type: Optional[SubModelType] = None):

@ -1,11 +1,15 @@
"""
Initialization file for the invokeai.backend.stable_diffusion package
"""
from .diffusers_pipeline import (
from .diffusers_pipeline import (  # noqa: F401
ConditioningData,
PipelineIntermediateState,
StableDiffusionGeneratorPipeline,
)
from .diffusion import InvokeAIDiffuserComponent
from .diffusion.cross_attention_map_saving import AttentionMapSaver
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings, BasicConditioningInfo, SDXLConditioningInfo
from .diffusion import InvokeAIDiffuserComponent  # noqa: F401
from .diffusion.cross_attention_map_saving import AttentionMapSaver  # noqa: F401
from .diffusion.shared_invokeai_diffusion import (  # noqa: F401
PostprocessingSettings,
BasicConditioningInfo,
SDXLConditioningInfo,
)

@ -2,10 +2,8 @@ from __future__ import annotations

import dataclasses
import inspect
import math
import secrets
from dataclasses import dataclass, field
from typing import Any, Callable, Generic, List, Optional, Type, Union
from typing import Any, Callable, List, Optional, Union

import PIL.Image
import einops

@ -1,9 +1,9 @@
"""
Initialization file for invokeai.models.diffusion
"""
from .cross_attention_control import InvokeAICrossAttentionMixin
from .cross_attention_map_saving import AttentionMapSaver
from .shared_invokeai_diffusion import (
from .cross_attention_control import InvokeAICrossAttentionMixin  # noqa: F401
from .cross_attention_map_saving import AttentionMapSaver  # noqa: F401
from .shared_invokeai_diffusion import (  # noqa: F401
InvokeAIDiffuserComponent,
PostprocessingSettings,
BasicConditioningInfo,

@ -4,6 +4,7 @@

import enum
import math
from dataclasses import dataclass, field
from typing import Callable, Optional

import diffusers
@ -12,6 +13,11 @@ import torch
from compel.cross_attention_control import Arguments
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.attention_processor import (
Attention,
AttnProcessor,
SlicedAttnProcessor,
)
from torch import nn

import invokeai.backend.util.logging as logger
@ -522,14 +528,6 @@ class AttnProcessor:
return hidden_states

"""
from dataclasses import dataclass, field

import torch
from diffusers.models.attention_processor import (
Attention,
AttnProcessor,
SlicedAttnProcessor,
)


@dataclass

@ -5,8 +5,6 @@ import torch
from torchvision.transforms.functional import InterpolationMode
from torchvision.transforms.functional import resize as tv_resize

from .cross_attention_control import CrossAttentionType, get_cross_attention_modules


class AttentionMapSaver:
def __init__(self, token_ids: range, latents_shape: torch.Size):

@ -3,15 +3,12 @@ from __future__ import annotations
from contextlib import contextmanager
from dataclasses import dataclass
import math
from typing import Any, Callable, Dict, Optional, Union, List
from typing import Any, Callable, Optional, Union

import numpy as np
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttentionProcessor
from typing_extensions import TypeAlias

import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig

from .cross_attention_control import (
@ -240,6 +237,7 @@ class InvokeAIDiffuserComponent:
controlnet_cond=control_datum.image_tensor,
conditioning_scale=controlnet_weight,  # controlnet specific, NOT the guidance scale
encoder_attention_mask=encoder_attention_mask,
added_cond_kwargs=added_cond_kwargs,
guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
return_dict=False,
)
@ -578,7 +576,7 @@ class InvokeAIDiffuserComponent:
latents.to(device="cpu")

if (
h_symmetry_time_pct != None
h_symmetry_time_pct is not None
and self.last_percent_through < h_symmetry_time_pct
and percent_through >= h_symmetry_time_pct
):
@ -594,7 +592,7 @@ class InvokeAIDiffuserComponent:
)

if (
v_symmetry_time_pct != None
v_symmetry_time_pct is not None
and self.last_percent_through < v_symmetry_time_pct
and percent_through >= v_symmetry_time_pct
):
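
The `!= None` to `is not None` fixes above (flake8 E711) matter because `==` can be overloaded, while `is` compares identity and cannot be fooled. A tiny illustration with a hypothetical class whose `__eq__` answers everything:

class Chatty:
    def __eq__(self, other):
        return True  # pathological, but legal

c = Chatty()
print(c == None)      # True -- misleading, routed through __eq__
print(c != None)      # False -- would wrongly treat c as None
print(c is not None)  # True -- identity check is unaffected by __eq__

NumPy arrays are the everyday case: `array == None` is elementwise, so the identity form is the only safe None test.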

@ -1,6 +1,6 @@
from ldm.modules.image_degradation.bsrgan import (
from ldm.modules.image_degradation.bsrgan import (  # noqa: F401
degradation_bsrgan_variant as degradation_fn_bsr,
)
from ldm.modules.image_degradation.bsrgan_light import (
from ldm.modules.image_degradation.bsrgan_light import (  # noqa: F401
degradation_bsrgan_variant as degradation_fn_bsr_light,
)

@ -573,14 +573,15 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
"""
image = util.uint2single(image)
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
sf_ori = sf
jpeg_prob, scale2_prob = 0.9, 0.25
# isp_prob = 0.25  # uncomment with `if i== 6` block below
# sf_ori = sf  # uncomment with `if i== 6` block below

h1, w1 = image.shape[:2]
image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...]  # mod crop
h, w = image.shape[:2]

hq = image.copy()
# hq = image.copy()  # uncomment with `if i== 6` block below

if sf == 4 and random.random() < scale2_prob:  # downsample1
if np.random.rand() < 0.5:
@ -777,7 +778,7 @@ if __name__ == "__main__":
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
print(img_lq.shape)
print("bicubic", img_lq_bicubic.shape)
print(img_hq.shape)
# print(img_hq.shape)
lq_nearest = cv2.resize(
util.single2uint(img_lq),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
@ -788,5 +789,6 @@ if __name__ == "__main__":
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
# img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest], axis=1)
util.imsave(img_concat, str(i) + ".png")

@ -577,14 +577,15 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
"""
image = util.uint2single(image)
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
sf_ori = sf
jpeg_prob, scale2_prob = 0.9, 0.25
# isp_prob = 0.25  # uncomment with `if i== 6` block below
# sf_ori = sf  # uncomment with `if i== 6` block below

h1, w1 = image.shape[:2]
image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...]  # mod crop
h, w = image.shape[:2]

hq = image.copy()
# hq = image.copy()  # uncomment with `if i== 6` block below

if sf == 4 and random.random() < scale2_prob:  # downsample1
if np.random.rand() < 0.5:

@ -8,8 +8,6 @@ import numpy as np
import torch
from torchvision.utils import make_grid

# import matplotlib.pyplot as plt  # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py

import invokeai.backend.util.logging as logger

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
@ -50,6 +48,8 @@ def get_timestamp():


def imshow(x, title=None, cbar=False, figsize=None):
import matplotlib.pyplot as plt

plt.figure(figsize=figsize)
plt.imshow(np.squeeze(x), interpolation="nearest", cmap="gray")
if title:
@ -60,6 +60,8 @@ def imshow(x, title=None, cbar=False, figsize=None):


def surf(Z, cmap="rainbow", figsize=None):
import matplotlib.pyplot as plt

plt.figure(figsize=figsize)
ax3 = plt.axes(projection="3d")

@ -89,7 +91,7 @@ def get_image_paths(dataroot):
def _get_paths_from_images(path):
assert os.path.isdir(path), "{:s} is not a valid directory".format(path)
images = []
for dirpath, _, fnames in sorted(os.walk(path)):
for dirpath, _, fnames in sorted(os.walk(path, followlinks=True)):
for fname in sorted(fnames):
if is_image_file(fname):
img_path = os.path.join(dirpath, fname)
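
The `import matplotlib.pyplot as plt` lines moved into `imshow` and `surf` above are a deferred-import pattern: the module-level import is dropped so the package loads even when matplotlib is absent, and the import cost is paid only if a plotting helper is actually called. A minimal sketch of the shape:

def imshow(x):
    # Deferred import: only plotting code paths require matplotlib.
    import matplotlib.pyplot as plt

    plt.imshow(x, interpolation="nearest", cmap="gray")
    plt.show()

Python caches modules in `sys.modules`, so repeated calls pay the import price once.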

@ -1 +1 @@
from .schedulers import SCHEDULER_MAP
from .schedulers import SCHEDULER_MAP  # noqa: F401

@ -1,4 +1,4 @@
"""
Initialization file for invokeai.backend.training
"""
from .textual_inversion_training import do_textual_inversion_training, parse_args
from .textual_inversion_training import do_textual_inversion_training, parse_args  # noqa: F401

@ -1,7 +1,7 @@
"""
Initialization file for invokeai.backend.util
"""
from .devices import (
from .devices import (  # noqa: F401
CPU_DEVICE,
CUDA_DEVICE,
MPS_DEVICE,
@ -10,5 +10,5 @@ from .devices import (
normalize_device,
torch_dtype,
)
from .log import write_log
from .util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name, Chdir
from .log import write_log  # noqa: F401
from .util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name, Chdir  # noqa: F401

@ -4,8 +4,15 @@ import torch
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalControlnetMixin
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.embeddings import (
TextImageProjection,
TextImageTimeEmbedding,
TextTimeEmbedding,
TimestepEmbedding,
Timesteps,
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import (
CrossAttnDownBlock2D,
@ -18,10 +25,16 @@ from diffusers.models.unet_2d_condition import UNet2DConditionModel
import diffusers
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module

from invokeai.backend.util.logging import InvokeAILogger

# TODO: create PR to diffusers
# Modified ControlNetModel with encoder_attention_mask argument added


class ControlNetModel(ModelMixin, ConfigMixin):
logger = InvokeAILogger.getLogger(__name__)


class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
"""
A ControlNet model.

@ -52,12 +65,25 @@ class ControlNetModel(ModelMixin, ConfigMixin):
The epsilon to use for the normalization.
cross_attention_dim (`int`, defaults to 1280):
The dimension of the cross attention features.
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
encoder_hid_dim (`int`, *optional*, defaults to None):
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
dimension to `cross_attention_dim`.
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
The dimension of the attention heads.
use_linear_projection (`bool`, defaults to `False`):
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
addition_embed_type (`str`, *optional*, defaults to `None`):
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
"text". "text" will use the `TextTimeEmbedding` layer.
num_class_embeds (`int`, *optional*, defaults to 0):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
@ -90,7 +116,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
"DownBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
@ -98,10 +124,15 @@ class ControlNetModel(ModelMixin, ConfigMixin):
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
@ -109,6 +140,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
global_pool_conditions: bool = False,
addition_embed_type_num_heads=64,
):
super().__init__()

@ -136,6 +168,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)

if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

# input
conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
@ -145,16 +180,43 @@ class ControlNetModel(ModelMixin, ConfigMixin):

# time
time_embed_dim = block_out_channels[0] * 4

self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]

self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
act_fn=act_fn,
)

if encoder_hid_dim_type is None and encoder_hid_dim is not None:
encoder_hid_dim_type = "text_proj"
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

if encoder_hid_dim is None and encoder_hid_dim_type is not None:
raise ValueError(
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
)

if encoder_hid_dim_type == "text_proj":
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
elif encoder_hid_dim_type == "text_image_proj":
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
self.encoder_hid_proj = TextImageProjection(
text_embed_dim=encoder_hid_dim,
image_embed_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim,
)

elif encoder_hid_dim_type is not None:
raise ValueError(
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
)
else:
self.encoder_hid_proj = None

# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
@ -178,6 +240,29 @@ class ControlNetModel(ModelMixin, ConfigMixin):
else:
self.class_embedding = None

if addition_embed_type == "text":
if encoder_hid_dim is not None:
text_time_embedding_from_dim = encoder_hid_dim
else:
text_time_embedding_from_dim = cross_attention_dim

self.add_embedding = TextTimeEmbedding(
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
)
elif addition_embed_type == "text_image":
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
self.add_embedding = TextImageTimeEmbedding(
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
)
elif addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)

elif addition_embed_type is not None:
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")

# control net conditioning embedding
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
conditioning_embedding_channels=block_out_channels[0],
@ -212,6 +297,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
transformer_layers_per_block=transformer_layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
@ -248,6 +334,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
self.controlnet_mid_block = controlnet_block

self.mid_block = UNetMidBlock2DCrossAttn(
transformer_layers_per_block=transformer_layers_per_block[-1],
in_channels=mid_block_channel,
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
@ -277,7 +364,22 @@ class ControlNetModel(ModelMixin, ConfigMixin):
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
where applicable.
"""
transformer_layers_per_block = (
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
)
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
addition_time_embed_dim = (
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
)

controlnet = cls(
encoder_hid_dim=encoder_hid_dim,
encoder_hid_dim_type=encoder_hid_dim_type,
addition_embed_type=addition_embed_type,
addition_time_embed_dim=addition_time_embed_dim,
transformer_layers_per_block=transformer_layers_per_block,
in_channels=unet.config.in_channels,
flip_sin_to_cos=unet.config.flip_sin_to_cos,
freq_shift=unet.config.freq_shift,
@ -463,6 +565,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
guess_mode: bool = False,
@ -486,7 +589,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
added_cond_kwargs (`dict`):
Additional conditions for the Stable Diffusion XL UNet.
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
encoder_attention_mask (`torch.Tensor`):
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
@ -549,6 +654,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
t_emb = t_emb.to(dtype=sample.dtype)

emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None

if self.class_embedding is not None:
if class_labels is None:
@ -560,11 +666,34 @@ class ControlNetModel(ModelMixin, ConfigMixin):
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb

if "addition_embed_type" in self.config:
if self.config.addition_embed_type == "text":
aug_emb = self.add_embedding(encoder_hidden_states)

elif self.config.addition_embed_type == "text_time":
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)

emb = emb + aug_emb if aug_emb is not None else emb

# 2. pre-process
sample = self.conv_in(sample)

controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)

sample = sample + controlnet_cond

# 3. down
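
The `text_time` branch added to `forward` above flattens SDXL's `time_ids` through a projection, reshapes per batch item, and concatenates with the pooled text embeddings before a final embedding layer. A shape-level sketch of that concatenation with plain tensors, where the projection is stood in by random values and the dimensions are illustrative rather than SDXL's real sizes:

import torch

batch, pooled_dim, time_embed = 2, 8, 4
text_embeds = torch.randn(batch, pooled_dim)

# SDXL passes six conditioning values per sample (original size, crop, target size);
# this randn stands in for add_time_proj(time_ids.flatten())
time_embeds = torch.randn(batch * 6, time_embed)
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
print(add_embeds.shape)  # torch.Size([2, 32]) -- pooled_dim + 6 * time_embed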

@ -27,8 +27,8 @@ def write_log_message(results, output_cntr):
log_lines = [f"{path}: {prompt}\n" for path, prompt in results]
if len(log_lines) > 1:
subcntr = 1
for l in log_lines:
print(f"[{output_cntr}.{subcntr}] {l}", end="")
for ll in log_lines:
print(f"[{output_cntr}.{subcntr}] {ll}", end="")
subcntr += 1
else:
print(f"[{output_cntr}] {log_lines[0]}", end="")

@ -182,13 +182,13 @@ import urllib.parse
from abc import abstractmethod
from pathlib import Path

from invokeai.app.services.config import InvokeAIAppConfig, get_invokeai_config
from invokeai.app.services.config import InvokeAIAppConfig

try:
import syslog

SYSLOG_AVAILABLE = True
except:
except ImportError:
SYSLOG_AVAILABLE = False


@ -417,7 +417,7 @@ class InvokeAILogger(object):
syslog_args["socktype"] = _SOCK_MAP[arg_value[0]]
else:
syslog_args["address"] = arg_name
except:
except Exception:
raise ValueError(f"{args} is not a value argument list for syslog logging")
return logging.handlers.SysLogHandler(**syslog_args)
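
Narrowing the syslog guard from `except:` to `except ImportError:` keeps the optional-dependency idiom honest: only a genuinely missing module flips the feature flag, while unrelated errors still surface. The general shape of the pattern:

try:
    import syslog  # POSIX-only module; absent on Windows
    SYSLOG_AVAILABLE = True
except ImportError:
    SYSLOG_AVAILABLE = False

def log_to_syslog(message: str) -> None:
    if SYSLOG_AVAILABLE:
        syslog.syslog(message)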

@ -191,7 +191,7 @@ class ChunkedSlicedAttnProcessor:
assert value.shape[0] == 1
assert hidden_states.shape[0] == 1

dtype = query.dtype
# dtype = query.dtype
if attn.upcast_attention:
query = query.float()
key = key.float()

@ -84,7 +84,7 @@ def count_params(model, verbose=False):


def instantiate_from_config(config, **kwargs):
if not "target" in config:
if "target" not in config:
if config == "__is_first_stage__":
return None
elif config == "__is_unconditional__":
@ -234,16 +234,17 @@ def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10
.repeat_interleave(d[1], 1)
)

dot = lambda grad, shift: (
torch.stack(
(
grid[: shape[0], : shape[1], 0] + shift[0],
grid[: shape[0], : shape[1], 1] + shift[1],
),
dim=-1,
)
* grad[: shape[0], : shape[1]]
).sum(dim=-1)
def dot(grad, shift):
return (
torch.stack(
(
grid[: shape[0], : shape[1], 0] + shift[0],
grid[: shape[0], : shape[1], 1] + shift[1],
),
dim=-1,
)
* grad[: shape[0], : shape[1]]
).sum(dim=-1)

n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]).to(device)
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]).to(device)
@ -287,7 +288,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
if dest.is_dir():
try:
file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)
except:
except AttributeError:
file_name = os.path.basename(url)
dest = dest / file_name
else:
@ -342,7 +343,7 @@ def url_attachment_name(url: str) -> dict:
resp = requests.get(url, stream=True)
match = re.search('filename="(.+)"', resp.headers.get("Content-Disposition"))
return match.group(1)
except:
except Exception:
return None
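
The `dot = lambda ...` to `def dot(...)` rewrite above is flake8's E731: an assigned lambda is just a `def` with a worse traceback name and no docstring slot. Side by side, with a deliberately simple function:

dot_lambda = lambda a, b: a[0] * b[0] + a[1] * b[1]  # discouraged: named function via lambda

def dot(a, b):
    # 2-D dot product
    return a[0] * b[0] + a[1] * b[1]

print(dot_lambda((1, 2), (3, 4)) == dot((1, 2), (3, 4)))  # True -- behavior is identical
print(dot_lambda.__name__, dot.__name__)  # '<lambda>' vs 'dot' -- only the def carries its name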

@ -1,4 +1,4 @@
"""
Initialization file for invokeai.frontend.CLI
"""
from .CLI import main as invokeai_command_line_interface
from .CLI import main as invokeai_command_line_interface  # noqa: F401

@ -1,4 +1,4 @@
"""
Wrapper for invokeai.backend.configure.invokeai_configure
"""
from ...backend.install.invokeai_configure import main as invokeai_configure
from ...backend.install.invokeai_configure import main as invokeai_configure  # noqa: F401

@ -80,7 +80,7 @@ def welcome(versions: dict):
def get_extras():
extras = ""
try:
dist = pkg_resources.get_distribution("xformers")
_ = pkg_resources.get_distribution("xformers")
extras = "[xformers]"
except pkg_resources.DistributionNotFound:
pass
@ -90,7 +90,7 @@ def get_extras():
def main():
versions = get_versions()
if invokeai_is_running():
print(f":exclamation: [bold red]Please terminate all running instances of InvokeAI before updating.[/red bold]")
print(":exclamation: [bold red]Please terminate all running instances of InvokeAI before updating.[/red bold]")
input("Press any key to continue...")
return

@ -122,9 +122,9 @@ def main():
print("")
print("")
if os.system(cmd) == 0:
print(f":heavy_check_mark: Upgrade successful")
print(":heavy_check_mark: Upgrade successful")
else:
print(f":exclamation: [bold red]Upgrade failed[/red bold]")
print(":exclamation: [bold red]Upgrade failed[/red bold]")


if __name__ == "__main__":
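
The `print(f"...")` to `print("...")` changes above are flake8's F541: an f-string with no placeholders is a plain string with extra syntax, and the stray prefix usually signals a forgotten interpolation. For example:

version = "3.0.0"
print(f"Upgrade successful")               # F541: no {} fields, the f-prefix is inert
print("Upgrade successful")                # equivalent, and what the commit switches to
print(f"Upgrade to {version} successful")  # the f-prefix is only needed here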

@ -251,7 +251,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
) -> dict[str, npyscreen.widget]:
"""Generic code to create model selection widgets"""
widgets = dict()
model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and not x in exclude]
model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude]
model_labels = [self.model_labels[x] for x in model_list]

show_recommended = len(self.installed_models) == 0
@ -357,14 +357,14 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
try:
v.hidden = True
v.editable = False
except:
except Exception:
pass
for k, v in widgets[selected_tab].items():
try:
v.hidden = False
if not isinstance(v, (npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)):
v.editable = True
except:
except Exception:
pass
self.__class__.current_tab = selected_tab  # for persistence
self.display()
@ -541,7 +541,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
self.ti_models,
]
for section in ui_sections:
if not "models_selected" in section:
if "models_selected" not in section:
continue
selected = set([section["models"][x] for x in section["models_selected"].value])
models_to_install = [x for x in selected if not self.all_models[x].installed]
@ -637,7 +637,7 @@ def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection) -> SchedulerPre
return None
else:
return response
except:
except Exception:
return None


@ -673,8 +673,7 @@ def process_and_execute(
def select_and_download_models(opt: Namespace):
precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
config.precision = precision
helper = lambda x: ask_user_for_prediction_type(x)
installer = ModelInstall(config, prediction_type_helper=helper)
installer = ModelInstall(config, prediction_type_helper=ask_user_for_prediction_type)
if opt.list_models:
installer.list_models(opt.list_models)
elif opt.add or opt.delete:

@ -102,8 +102,8 @@ def set_min_terminal_size(min_cols: int, min_lines: int) -> bool:
class IntSlider(npyscreen.Slider):
def translate_value(self):
stri = "%2d / %2d" % (self.value, self.out_of)
l = (len(str(self.out_of))) * 2 + 4
stri = stri.rjust(l)
length = (len(str(self.out_of))) * 2 + 4
stri = stri.rjust(length)
return stri


@ -167,8 +167,8 @@ class FloatSlider(npyscreen.Slider):
# this is supposed to adjust display precision, but doesn't
def translate_value(self):
stri = "%3.2f / %3.2f" % (self.value, self.out_of)
l = (len(str(self.out_of))) * 2 + 4
stri = stri.rjust(l)
length = (len(str(self.out_of))) * 2 + 4
stri = stri.rjust(length)
return stri

@ -1,4 +1,3 @@
import os
import sys
import argparse

@ -1,4 +1,4 @@
"""
Initialization file for invokeai.frontend.merge
"""
from .merge_diffusers import main as invokeai_merge_diffusers
from .merge_diffusers import main as invokeai_merge_diffusers  # noqa: F401

@ -9,19 +9,15 @@ import curses
import sys
from argparse import Namespace
from pathlib import Path
from typing import List, Union
from typing import List, Optional

import npyscreen
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
from npyscreen import widget
from omegaconf import OmegaConf

import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import (
ModelMerger,
MergeInterpolationMethod,
ModelManager,
ModelType,
BaseModelType,
@ -318,7 +314,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
else:
return True

def get_model_names(self, base_model: BaseModelType = None) -> List[str]:
def get_model_names(self, base_model: Optional[BaseModelType] = None) -> List[str]:
model_names = [
info["model_name"]
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
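
Rewriting `base_model: BaseModelType = None` as `Optional[BaseModelType] = None` makes the annotation tell the truth: a default of `None` does not implicitly widen the declared type, and modern type checkers reject the implicit form. A minimal sketch with a plain string in place of the real enum:

from typing import List, Optional

def get_model_names(base_model: Optional[str] = None) -> List[str]:
    # None means "all bases"; the Optional[...] annotation makes that explicit
    names = ["sd-1", "sd-2", "sdxl"]
    if base_model is None:
        return names
    return [name for name in names if name.startswith(base_model)]

print(get_model_names(), get_model_names("sd-2"))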

@ -1,4 +1,4 @@
"""
Initialization file for invokeai.frontend.training
"""
from .textual_inversion import main as invokeai_textual_inversion
from .textual_inversion import main as invokeai_textual_inversion  # noqa: F401

@ -59,7 +59,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):

try:
default = self.model_names.index(saved_args["model"])
except:
except Exception:
pass

self.add_widget_intelligent(
@ -377,7 +377,7 @@ def previous_args() -> dict:
try:
conf = OmegaConf.load(conf_file)
conf["placeholder_token"] = conf["placeholder_token"].strip("<>")
except:
except Exception:
conf = None

return conf

@ -506,10 +506,14 @@
"maskAdjustmentsHeader": "Mask Adjustments",
"maskBlur": "Mask Blur",
"maskBlurMethod": "Mask Blur Method",
"seamPaintingHeader": "Seam Painting",
"seamSize": "Seam Size",
"seamBlur": "Seam Blur",
"seamStrength": "Seam Strength",
"seamSteps": "Seam Steps",
"seamStrength": "Seam Strength",
"seamThreshold": "Seam Threshold",
"seamLowThreshold": "Low",
"seamHighThreshold": "High",
"scaleBeforeProcessing": "Scale Before Processing",
"scaledWidth": "Scaled W",
"scaledHeight": "Scaled H",

@ -121,7 +121,7 @@ export const addRequestedMultipleImageDeletionListener = () => {
effect: async (action, { dispatch, getState }) => {
const { imageDTOs, imagesUsage } = action.payload;

if (imageDTOs.length < 1 || imagesUsage.length < 1) {
if (imageDTOs.length <= 1 || imagesUsage.length <= 1) {
// handle singles in separate listener
return;
}

@ -32,6 +32,7 @@ import {
MAIN_MODEL_LOADER,
MASK_BLUR,
MASK_COMBINE,
MASK_EDGE,
MASK_FROM_ALPHA,
MASK_RESIZE_DOWN,
MASK_RESIZE_UP,
@ -40,6 +41,10 @@ import {
POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
SEAM_FIX_DENOISE_LATENTS,
SEAM_MASK_COMBINE,
SEAM_MASK_RESIZE_DOWN,
SEAM_MASK_RESIZE_UP,
} from './constants';

/**
@ -67,6 +72,12 @@ export const buildCanvasOutpaintGraph = (
shouldUseCpuNoise,
maskBlur,
maskBlurMethod,
seamSize,
seamBlur,
seamSteps,
seamStrength,
seamLowThreshold,
seamHighThreshold,
tileSize,
infillMethod,
clipSkip,
@ -130,6 +141,11 @@ export const buildCanvasOutpaintGraph = (
is_intermediate: true,
mask2: canvasMaskImage,
},
[SEAM_MASK_COMBINE]: {
type: 'mask_combine',
id: MASK_COMBINE,
is_intermediate: true,
},
[MASK_BLUR]: {
type: 'img_blur',
id: MASK_BLUR,
@ -165,6 +181,25 @@ export const buildCanvasOutpaintGraph = (
denoising_start: 1 - strength,
denoising_end: 1,
},
[MASK_EDGE]: {
type: 'mask_edge',
id: MASK_EDGE,
is_intermediate: true,
edge_size: seamSize,
edge_blur: seamBlur,
low_threshold: seamLowThreshold,
high_threshold: seamHighThreshold,
},
[SEAM_FIX_DENOISE_LATENTS]: {
type: 'denoise_latents',
id: SEAM_FIX_DENOISE_LATENTS,
is_intermediate: true,
steps: seamSteps,
cfg_scale: cfg_scale,
scheduler: scheduler,
denoising_start: 1 - seamStrength,
denoising_end: 1,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
@ -333,12 +368,63 @@ export const buildCanvasOutpaintGraph = (
field: 'seed',
},
},
// Decode the result from Inpaint
// Seam Paint
{
source: {
node_id: MAIN_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'unet',
},
},
{
source: {
node_id: POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'positive_conditioning',
},
},
{
source: {
node_id: NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'negative_conditioning',
},
},
{
source: {
node_id: NOISE,
field: 'noise',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'noise',
},
},
{
source: {
node_id: DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'latents',
},
},
// Decode the result from Inpaint
{
source: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
@ -348,7 +434,6 @@ export const buildCanvasOutpaintGraph = (
};

// Add Infill Nodes

if (infillMethod === 'patchmatch') {
graph.nodes[INPAINT_INFILL] = {
type: 'infill_patchmatch',
@ -378,6 +463,13 @@ export const buildCanvasOutpaintGraph = (
width: scaledWidth,
height: scaledHeight,
};
graph.nodes[SEAM_MASK_RESIZE_UP] = {
type: 'img_resize',
id: SEAM_MASK_RESIZE_UP,
is_intermediate: true,
width: scaledWidth,
height: scaledHeight,
};
graph.nodes[INPAINT_IMAGE_RESIZE_DOWN] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_DOWN,
@ -399,6 +491,13 @@ export const buildCanvasOutpaintGraph = (
width: width,
height: height,
};
graph.nodes[SEAM_MASK_RESIZE_DOWN] = {
type: 'img_resize',
id: SEAM_MASK_RESIZE_DOWN,
is_intermediate: true,
width: width,
height: height,
};

graph.nodes[NOISE] = {
...(graph.nodes[NOISE] as NoiseInvocation),
@ -440,6 +539,57 @@ export const buildCanvasOutpaintGraph = (
field: 'image',
},
},
// Seam Paint Mask
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'image',
},
destination: {
node_id: MASK_EDGE,
field: 'image',
},
},
{
source: {
node_id: MASK_EDGE,
field: 'image',
},
destination: {
node_id: SEAM_MASK_RESIZE_UP,
field: 'image',
},
},
{
source: {
node_id: SEAM_MASK_RESIZE_UP,
field: 'image',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'mask',
},
},
{
source: {
node_id: MASK_BLUR,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask1',
},
},
{
source: {
node_id: SEAM_MASK_RESIZE_UP,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask2',
},
},
// Resize Results Down
{
source: {
@ -453,7 +603,7 @@ export const buildCanvasOutpaintGraph = (
},
{
source: {
node_id: MASK_BLUR,
node_id: MASK_RESIZE_UP,
field: 'image',
},
destination: {
@ -461,6 +611,16 @@ export const buildCanvasOutpaintGraph = (
field: 'image',
},
},
{
source: {
node_id: SEAM_MASK_COMBINE,
field: 'image',
},
destination: {
node_id: SEAM_MASK_RESIZE_DOWN,
field: 'image',
},
},
{
source: {
node_id: INPAINT_INFILL,
@ -494,7 +654,7 @@ export const buildCanvasOutpaintGraph = (
},
{
source: {
node_id: MASK_RESIZE_DOWN,
node_id: SEAM_MASK_RESIZE_DOWN,
field: 'image',
},
destination: {
@ -525,7 +685,7 @@ export const buildCanvasOutpaintGraph = (
},
{
source: {
node_id: MASK_RESIZE_DOWN,
node_id: SEAM_MASK_RESIZE_DOWN,
field: 'image',
},
destination: {
@ -553,7 +713,6 @@ export const buildCanvasOutpaintGraph = (
};
graph.nodes[MASK_BLUR] = {
...(graph.nodes[MASK_BLUR] as ImageBlurInvocation),
image: canvasMaskImage,
};

graph.edges.push(
@ -568,6 +727,47 @@ export const buildCanvasOutpaintGraph = (
field: 'image',
},
},
// Seam Paint Mask
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'image',
},
destination: {
node_id: MASK_EDGE,
field: 'image',
},
},
{
source: {
node_id: MASK_EDGE,
field: 'image',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'mask',
},
},
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask1',
},
},
{
source: {
node_id: MASK_EDGE,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask2',
},
},
// Color Correct The Inpainted Result
{
source: {
@ -591,7 +791,7 @@ export const buildCanvasOutpaintGraph = (
},
{
source: {
node_id: MASK_BLUR,
node_id: SEAM_MASK_COMBINE,
field: 'image',
},
destination: {
@ -622,7 +822,7 @@ export const buildCanvasOutpaintGraph = (
},
{
source: {
node_id: MASK_BLUR,
node_id: SEAM_MASK_COMBINE,
field: 'image',
},
destination: {
|
||||
|
@ -29,6 +29,7 @@ import {
LATENTS_TO_IMAGE,
MASK_BLUR,
MASK_COMBINE,
MASK_EDGE,
MASK_FROM_ALPHA,
MASK_RESIZE_DOWN,
MASK_RESIZE_UP,
@ -40,6 +41,10 @@ import {
SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER,
SEAM_FIX_DENOISE_LATENTS,
SEAM_MASK_COMBINE,
SEAM_MASK_RESIZE_DOWN,
SEAM_MASK_RESIZE_UP,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';

@ -67,6 +72,12 @@ export const buildCanvasSDXLOutpaintGraph = (
shouldUseCpuNoise,
maskBlur,
maskBlurMethod,
seamSize,
seamBlur,
seamSteps,
seamStrength,
seamLowThreshold,
seamHighThreshold,
tileSize,
infillMethod,
} = state.generation;
@ -133,6 +144,11 @@ export const buildCanvasSDXLOutpaintGraph = (
is_intermediate: true,
mask2: canvasMaskImage,
},
[SEAM_MASK_COMBINE]: {
type: 'mask_combine',
id: SEAM_MASK_COMBINE,
is_intermediate: true,
},
[MASK_BLUR]: {
type: 'img_blur',
id: MASK_BLUR,
@ -170,6 +186,25 @@ export const buildCanvasSDXLOutpaintGraph = (
: 1 - strength,
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
},
[MASK_EDGE]: {
type: 'mask_edge',
id: MASK_EDGE,
is_intermediate: true,
edge_size: seamSize,
edge_blur: seamBlur,
low_threshold: seamLowThreshold,
high_threshold: seamHighThreshold,
},
[SEAM_FIX_DENOISE_LATENTS]: {
type: 'denoise_latents',
id: SEAM_FIX_DENOISE_LATENTS,
is_intermediate: true,
steps: seamSteps,
cfg_scale: cfg_scale,
scheduler: scheduler,
denoising_start: 1 - seamStrength,
denoising_end: 1,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
@ -347,12 +382,63 @@ export const buildCanvasSDXLOutpaintGraph = (
field: 'seed',
},
},
// Decode inpainted latents to image
// Seam Paint
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'unet',
},
},
{
source: {
node_id: POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'positive_conditioning',
},
},
{
source: {
node_id: NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'negative_conditioning',
},
},
{
source: {
node_id: NOISE,
field: 'noise',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'noise',
},
},
{
source: {
node_id: SDXL_DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'latents',
},
},
// Decode inpainted latents to image
{
source: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
@ -392,6 +478,13 @@ export const buildCanvasSDXLOutpaintGraph = (
width: scaledWidth,
height: scaledHeight,
};
graph.nodes[SEAM_MASK_RESIZE_UP] = {
type: 'img_resize',
id: SEAM_MASK_RESIZE_UP,
is_intermediate: true,
width: scaledWidth,
height: scaledHeight,
};
graph.nodes[INPAINT_IMAGE_RESIZE_DOWN] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_DOWN,
@ -413,6 +506,13 @@ export const buildCanvasSDXLOutpaintGraph = (
width: width,
height: height,
};
graph.nodes[SEAM_MASK_RESIZE_DOWN] = {
type: 'img_resize',
id: SEAM_MASK_RESIZE_DOWN,
is_intermediate: true,
width: width,
height: height,
};

graph.nodes[NOISE] = {
...(graph.nodes[NOISE] as NoiseInvocation),
@ -454,6 +554,57 @@ export const buildCanvasSDXLOutpaintGraph = (
field: 'image',
},
},
// Seam Paint Mask
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'image',
},
destination: {
node_id: MASK_EDGE,
field: 'image',
},
},
{
source: {
node_id: MASK_EDGE,
field: 'image',
},
destination: {
node_id: SEAM_MASK_RESIZE_UP,
field: 'image',
},
},
{
source: {
node_id: SEAM_MASK_RESIZE_UP,
field: 'image',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'mask',
},
},
{
source: {
node_id: MASK_BLUR,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask1',
},
},
{
source: {
node_id: SEAM_MASK_RESIZE_UP,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask2',
},
},
// Resize Results Down
{
source: {
@ -467,7 +618,7 @@ export const buildCanvasSDXLOutpaintGraph = (
},
{
source: {
node_id: MASK_BLUR,
node_id: MASK_RESIZE_UP,
field: 'image',
},
destination: {
@ -475,6 +626,16 @@ export const buildCanvasSDXLOutpaintGraph = (
field: 'image',
},
},
{
source: {
node_id: SEAM_MASK_COMBINE,
field: 'image',
},
destination: {
node_id: SEAM_MASK_RESIZE_DOWN,
field: 'image',
},
},
{
source: {
node_id: INPAINT_INFILL,
@ -508,7 +669,7 @@ export const buildCanvasSDXLOutpaintGraph = (
},
{
source: {
node_id: MASK_RESIZE_DOWN,
node_id: SEAM_MASK_RESIZE_DOWN,
field: 'image',
},
destination: {
@ -539,7 +700,7 @@ export const buildCanvasSDXLOutpaintGraph = (
},
{
source: {
node_id: MASK_RESIZE_DOWN,
node_id: SEAM_MASK_RESIZE_DOWN,
field: 'image',
},
destination: {
@ -567,7 +728,6 @@ export const buildCanvasSDXLOutpaintGraph = (
};
graph.nodes[MASK_BLUR] = {
...(graph.nodes[MASK_BLUR] as ImageBlurInvocation),
image: canvasMaskImage,
};

graph.edges.push(
@ -582,6 +742,47 @@ export const buildCanvasSDXLOutpaintGraph = (
field: 'image',
},
},
// Seam Paint Mask
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'image',
},
destination: {
node_id: MASK_EDGE,
field: 'image',
},
},
{
source: {
node_id: MASK_EDGE,
field: 'image',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'mask',
},
},
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask1',
},
},
{
source: {
node_id: MASK_EDGE,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask2',
},
},
// Color Correct The Inpainted Result
{
source: {
@ -605,7 +806,7 @@ export const buildCanvasSDXLOutpaintGraph = (
},
{
source: {
node_id: MASK_BLUR,
node_id: SEAM_MASK_COMBINE,
field: 'image',
},
destination: {
@ -636,7 +837,7 @@ export const buildCanvasSDXLOutpaintGraph = (
},
{
source: {
node_id: MASK_BLUR,
node_id: SEAM_MASK_COMBINE,
field: 'image',
},
destination: {
@ -669,7 +870,7 @@ export const buildCanvasSDXLOutpaintGraph = (

// Add Refiner if enabled
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
addSDXLRefinerToGraph(state, graph, SEAM_FIX_DENOISE_LATENTS);
}

// optionally add custom VAE
|
@ -18,8 +18,6 @@ export const IMAGE_TO_LATENTS = 'image_to_latents';
export const LATENTS_TO_LATENTS = 'latents_to_latents';
export const RESIZE = 'resize_image';
export const CANVAS_OUTPUT = 'canvas_output';
export const INPAINT = 'inpaint';
export const INPAINT_SEAM_FIX = 'inpaint_seam_fix';
export const INPAINT_IMAGE = 'inpaint_image';
export const SCALED_INPAINT_IMAGE = 'scaled_inpaint_image';
export const INPAINT_IMAGE_RESIZE_UP = 'inpaint_image_resize_up';
@ -27,10 +25,14 @@ export const INPAINT_IMAGE_RESIZE_DOWN = 'inpaint_image_resize_down';
export const INPAINT_INFILL = 'inpaint_infill';
export const INPAINT_INFILL_RESIZE_DOWN = 'inpaint_infill_resize_down';
export const INPAINT_FINAL_IMAGE = 'inpaint_final_image';
export const SEAM_FIX_DENOISE_LATENTS = 'seam_fix_denoise_latents';
export const MASK_FROM_ALPHA = 'tomask';
export const MASK_EDGE = 'mask_edge';
export const MASK_BLUR = 'mask_blur';
export const MASK_COMBINE = 'mask_combine';
export const SEAM_MASK_COMBINE = 'seam_mask_combine';
export const SEAM_MASK_RESIZE_UP = 'seam_mask_resize_up';
export const SEAM_MASK_RESIZE_DOWN = 'seam_mask_resize_down';
export const MASK_RESIZE_UP = 'mask_resize_up';
export const MASK_RESIZE_DOWN = 'mask_resize_down';
export const COLOR_CORRECT = 'color_correct';
|
@ -0,0 +1,36 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setSeamBlur } from 'features/parameters/store/generationSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';

const ParamSeamBlur = () => {
const dispatch = useAppDispatch();
const seamBlur = useAppSelector(
(state: RootState) => state.generation.seamBlur
);
const { t } = useTranslation();

return (
<IAISlider
label={t('parameters.seamBlur')}
min={0}
max={64}
step={8}
sliderNumberInputProps={{ max: 512 }}
value={seamBlur}
onChange={(v) => {
dispatch(setSeamBlur(v));
}}
withInput
withSliderMarks
withReset
handleReset={() => {
dispatch(setSeamBlur(8));
}}
/>
);
};

export default memo(ParamSeamBlur);
|
@ -0,0 +1,27 @@
import { Flex } from '@chakra-ui/react';
import IAICollapse from 'common/components/IAICollapse';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import ParamSeamBlur from './ParamSeamBlur';
import ParamSeamSize from './ParamSeamSize';
import ParamSeamSteps from './ParamSeamSteps';
import ParamSeamStrength from './ParamSeamStrength';
import ParamSeamThreshold from './ParamSeamThreshold';

const ParamSeamPaintingCollapse = () => {
const { t } = useTranslation();

return (
<IAICollapse label={t('parameters.seamPaintingHeader')}>
<Flex sx={{ flexDirection: 'column', gap: 2, paddingBottom: 2 }}>
<ParamSeamSize />
<ParamSeamBlur />
<ParamSeamSteps />
<ParamSeamStrength />
<ParamSeamThreshold />
</Flex>
</IAICollapse>
);
};

export default memo(ParamSeamPaintingCollapse);
|
@ -0,0 +1,36 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setSeamSize } from 'features/parameters/store/generationSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';

const ParamSeamSize = () => {
const dispatch = useAppDispatch();
const seamSize = useAppSelector(
(state: RootState) => state.generation.seamSize
);
const { t } = useTranslation();

return (
<IAISlider
label={t('parameters.seamSize')}
min={0}
max={128}
step={8}
sliderNumberInputProps={{ max: 512 }}
value={seamSize}
onChange={(v) => {
dispatch(setSeamSize(v));
}}
withInput
withSliderMarks
withReset
handleReset={() => {
dispatch(setSeamSize(16));
}}
/>
);
};

export default memo(ParamSeamSize);
|
@ -0,0 +1,36 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setSeamSteps } from 'features/parameters/store/generationSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';

const ParamSeamSteps = () => {
const dispatch = useAppDispatch();
const seamSteps = useAppSelector(
(state: RootState) => state.generation.seamSteps
);
const { t } = useTranslation();

return (
<IAISlider
label={t('parameters.seamSteps')}
min={0}
max={100}
step={1}
sliderNumberInputProps={{ max: 999 }}
value={seamSteps}
onChange={(v) => {
dispatch(setSeamSteps(v));
}}
withInput
withSliderMarks
withReset
handleReset={() => {
dispatch(setSeamSteps(20));
}}
/>
);
};

export default memo(ParamSeamSteps);
|
@ -0,0 +1,36 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setSeamStrength } from 'features/parameters/store/generationSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';

const ParamSeamStrength = () => {
const dispatch = useAppDispatch();
const seamStrength = useAppSelector(
(state: RootState) => state.generation.seamStrength
);
const { t } = useTranslation();

return (
<IAISlider
label={t('parameters.seamStrength')}
min={0}
max={1}
step={0.01}
sliderNumberInputProps={{ max: 999 }}
value={seamStrength}
onChange={(v) => {
dispatch(setSeamStrength(v));
}}
withInput
withSliderMarks
withReset
handleReset={() => {
dispatch(setSeamStrength(0.7));
}}
/>
);
};

export default memo(ParamSeamStrength);
|
@ -0,0 +1,121 @@
import {
FormControl,
FormLabel,
HStack,
RangeSlider,
RangeSliderFilledTrack,
RangeSliderMark,
RangeSliderThumb,
RangeSliderTrack,
Tooltip,
} from '@chakra-ui/react';
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import {
setSeamHighThreshold,
setSeamLowThreshold,
} from 'features/parameters/store/generationSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { BiReset } from 'react-icons/bi';

const ParamSeamThreshold = () => {
const dispatch = useAppDispatch();
const seamLowThreshold = useAppSelector(
(state: RootState) => state.generation.seamLowThreshold
);

const seamHighThreshold = useAppSelector(
(state: RootState) => state.generation.seamHighThreshold
);
const { t } = useTranslation();

const handleSeamThresholdChange = useCallback(
(v: number[]) => {
dispatch(setSeamLowThreshold(v[0] as number));
dispatch(setSeamHighThreshold(v[1] as number));
},
[dispatch]
);

const handleSeamThresholdReset = () => {
dispatch(setSeamLowThreshold(100));
dispatch(setSeamHighThreshold(200));
};

return (
<FormControl>
<FormLabel>{t('parameters.seamThreshold')}</FormLabel>
<HStack w="100%" gap={4} mt={-2}>
<RangeSlider
aria-label={[
t('parameters.seamLowThreshold'),
t('parameters.seamHighThreshold'),
]}
value={[seamLowThreshold, seamHighThreshold]}
min={0}
max={255}
step={1}
minStepsBetweenThumbs={1}
onChange={handleSeamThresholdChange}
>
<RangeSliderTrack>
<RangeSliderFilledTrack />
</RangeSliderTrack>
<Tooltip label={seamLowThreshold} placement="top" hasArrow>
<RangeSliderThumb index={0} />
</Tooltip>
<Tooltip label={seamHighThreshold} placement="top" hasArrow>
<RangeSliderThumb index={1} />
</Tooltip>
<RangeSliderMark
value={0}
sx={{
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
}}
>
0
</RangeSliderMark>
<RangeSliderMark
value={0.392}
sx={{
insetInlineStart: '38.4% !important',
transform: 'translateX(-38.4%)',
}}
>
100
</RangeSliderMark>
<RangeSliderMark
value={0.784}
sx={{
insetInlineStart: '79.8% !important',
transform: 'translateX(-79.8%)',
}}
>
200
</RangeSliderMark>
<RangeSliderMark
value={1}
sx={{
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
}}
>
255
</RangeSliderMark>
</RangeSlider>
<IAIIconButton
size="sm"
aria-label={t('accessibility.reset')}
tooltip={t('accessibility.reset')}
icon={<BiReset />}
onClick={handleSeamThresholdReset}
/>
</HStack>
</FormControl>
);
};

export default memo(ParamSeamThreshold);
|
@ -37,6 +37,12 @@ export interface GenerationState {
scheduler: SchedulerParam;
maskBlur: number;
maskBlurMethod: MaskBlurMethodParam;
seamSize: number;
seamBlur: number;
seamSteps: number;
seamStrength: StrengthParam;
seamLowThreshold: number;
seamHighThreshold: number;
seed: SeedParam;
seedWeights: string;
shouldFitToWidthHeight: boolean;
@ -74,6 +80,12 @@ export const initialGenerationState: GenerationState = {
scheduler: 'euler',
maskBlur: 16,
maskBlurMethod: 'box',
seamSize: 16,
seamBlur: 8,
seamSteps: 20,
seamStrength: 0.7,
seamLowThreshold: 100,
seamHighThreshold: 200,
seed: 0,
seedWeights: '',
shouldFitToWidthHeight: true,
@ -200,6 +212,24 @@ export const generationSlice = createSlice({
setMaskBlurMethod: (state, action: PayloadAction<MaskBlurMethodParam>) => {
state.maskBlurMethod = action.payload;
},
setSeamSize: (state, action: PayloadAction<number>) => {
state.seamSize = action.payload;
},
setSeamBlur: (state, action: PayloadAction<number>) => {
state.seamBlur = action.payload;
},
setSeamSteps: (state, action: PayloadAction<number>) => {
state.seamSteps = action.payload;
},
setSeamStrength: (state, action: PayloadAction<number>) => {
state.seamStrength = action.payload;
},
setSeamLowThreshold: (state, action: PayloadAction<number>) => {
state.seamLowThreshold = action.payload;
},
setSeamHighThreshold: (state, action: PayloadAction<number>) => {
state.seamHighThreshold = action.payload;
},
setTileSize: (state, action: PayloadAction<number>) => {
state.tileSize = action.payload;
},
@ -306,6 +336,12 @@ export const {
setScheduler,
setMaskBlur,
setMaskBlurMethod,
setSeamSize,
setSeamBlur,
setSeamSteps,
setSeamStrength,
setSeamLowThreshold,
setSeamHighThreshold,
setSeed,
setSeedWeights,
setShouldFitToWidthHeight,
|
@ -2,6 +2,7 @@ import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/Para
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
import ParamInfillAndScalingCollapse from 'features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse';
import ParamMaskAdjustmentCollapse from 'features/parameters/components/Parameters/Canvas/MaskAdjustment/ParamMaskAdjustmentCollapse';
import ParamSeamPaintingCollapse from 'features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamPaintingCollapse';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
@ -22,6 +23,7 @@ export default function SDXLUnifiedCanvasTabParameters() {
<ParamNoiseCollapse />
<ParamMaskAdjustmentCollapse />
<ParamInfillAndScalingCollapse />
<ParamSeamPaintingCollapse />
</>
);
}
|
@ -4,7 +4,7 @@ import { InvokeLogLevel } from 'app/logging/logger';
import { userInvoked } from 'app/store/actions';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { t } from 'i18next';
import { startCase } from 'lodash-es';
import { startCase, upperFirst } from 'lodash-es';
import { LogLevelName } from 'roarr';
import {
isAnySessionRejected,
@ -26,6 +26,7 @@ import {
import { ProgressImage } from 'services/events/types';
import { makeToast } from '../util/makeToast';
import { LANGUAGES } from './constants';
import { zPydanticValidationError } from './zodSchemas';

export type CancelStrategy = 'immediate' | 'scheduled';

@ -361,9 +362,24 @@ export const systemSlice = createSlice({
state.progressImage = null;

let errorDescription = undefined;
const duration = 5000;

if (action.payload?.status === 422) {
errorDescription = 'Validation Error';
const result = zPydanticValidationError.safeParse(action.payload);
if (result.success) {
result.data.error.detail.map((e) => {
state.toastQueue.push(
makeToast({
title: upperFirst(e.msg),
status: 'error',
description: `Path:
${e.loc.slice(3).join('.')}`,
duration,
})
);
});
return;
}
} else if (action.payload?.error) {
errorDescription = action.payload?.error as string;
}
@ -373,6 +389,7 @@ export const systemSlice = createSlice({
title: t('toast.serverError'),
status: 'error',
description: errorDescription,
duration,
})
);
});
|
@ -0,0 +1,14 @@
import { z } from 'zod';

export const zPydanticValidationError = z.object({
status: z.literal(422),
error: z.object({
detail: z.array(
z.object({
loc: z.array(z.string()),
msg: z.string(),
type: z.string(),
})
),
}),
});