Merge branch 'main' into fix/inpaint_gen

This commit is contained in:
psychedelicious 2023-08-18 15:57:48 +10:00
commit 3c43594c26
106 changed files with 444 additions and 417 deletions

View File

@ -1,6 +1,6 @@
name: style checks name: style checks
# just formatting for now # just formatting and flake8 for now
# TODO: add isort and flake8 later # TODO: add isort later
on: on:
pull_request: pull_request:
@ -20,8 +20,8 @@ jobs:
- name: Install dependencies with pip - name: Install dependencies with pip
run: | run: |
pip install black pip install black flake8 Flake8-pyproject
# - run: isort --check-only . # - run: isort --check-only .
- run: black --check . - run: black --check .
# - run: flake8 - run: flake8

View File

@ -8,3 +8,10 @@ repos:
language: system language: system
entry: black entry: black
types: [python] types: [python]
- id: flake8
name: flake8
stages: [commit]
language: system
entry: flake8
types: [python]

View File

@ -407,7 +407,7 @@ def get_pip_from_venv(venv_path: Path) -> str:
:rtype: str :rtype: str
""" """
pip = "Scripts\pip.exe" if OS == "Windows" else "bin/pip" pip = "Scripts\\pip.exe" if OS == "Windows" else "bin/pip"
return str(venv_path.expanduser().resolve() / pip) return str(venv_path.expanduser().resolve() / pip)

View File

@ -49,7 +49,7 @@ if __name__ == "__main__":
try: try:
inst.install(**args.__dict__) inst.install(**args.__dict__)
except KeyboardInterrupt as exc: except KeyboardInterrupt:
print("\n") print("\n")
print("Ctrl-C pressed. Aborting.") print("Ctrl-C pressed. Aborting.")
print("Come back soon!") print("Come back soon!")

View File

@ -70,7 +70,7 @@ def confirm_install(dest: Path) -> bool:
) )
else: else:
print(f"InvokeAI will be installed in {dest}") print(f"InvokeAI will be installed in {dest}")
dest_confirmed = not Confirm.ask(f"Would you like to pick a different location?", default=False) dest_confirmed = not Confirm.ask("Would you like to pick a different location?", default=False)
console.line() console.line()
return dest_confirmed return dest_confirmed
@ -90,7 +90,7 @@ def dest_path(dest=None) -> Path:
dest = Path(dest).expanduser().resolve() dest = Path(dest).expanduser().resolve()
else: else:
dest = Path.cwd().expanduser().resolve() dest = Path.cwd().expanduser().resolve()
prev_dest = dest.expanduser().resolve() prev_dest = init_path = dest
dest_confirmed = confirm_install(dest) dest_confirmed = confirm_install(dest)
@ -109,9 +109,9 @@ def dest_path(dest=None) -> Path:
) )
console.line() console.line()
print(f"[orange3]Please select the destination directory for the installation:[/] \[{browse_start}]: ") console.print(f"[orange3]Please select the destination directory for the installation:[/] \\[{browse_start}]: ")
selected = prompt( selected = prompt(
f">>> ", ">>> ",
complete_in_thread=True, complete_in_thread=True,
completer=path_completer, completer=path_completer,
default=str(browse_start) + os.sep, default=str(browse_start) + os.sep,
@ -134,14 +134,14 @@ def dest_path(dest=None) -> Path:
try: try:
dest.mkdir(exist_ok=True, parents=True) dest.mkdir(exist_ok=True, parents=True)
return dest return dest
except PermissionError as exc: except PermissionError:
print( console.print(
f"Failed to create directory {dest} due to insufficient permissions", f"Failed to create directory {dest} due to insufficient permissions",
style=Style(color="red"), style=Style(color="red"),
highlight=True, highlight=True,
) )
except OSError as exc: except OSError:
console.print_exception(exc) console.print_exception()
if Confirm.ask("Would you like to try again?"): if Confirm.ask("Would you like to try again?"):
dest_path(init_path) dest_path(init_path)

View File

@ -1,6 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Optional
from logging import Logger from logging import Logger
from invokeai.app.services.board_image_record_storage import ( from invokeai.app.services.board_image_record_storage import (
SqliteBoardImageRecordStorage, SqliteBoardImageRecordStorage,
@ -45,7 +44,7 @@ def check_internet() -> bool:
try: try:
urllib.request.urlopen(host, timeout=1) urllib.request.urlopen(host, timeout=1)
return True return True
except: except Exception:
return False return False

View File

@ -34,7 +34,7 @@ async def add_image_to_board(
board_id=board_id, image_name=image_name board_id=board_id, image_name=image_name
) )
return result return result
except Exception as e: except Exception:
raise HTTPException(status_code=500, detail="Failed to add image to board") raise HTTPException(status_code=500, detail="Failed to add image to board")
@ -53,7 +53,7 @@ async def remove_image_from_board(
try: try:
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name) result = ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
return result return result
except Exception as e: except Exception:
raise HTTPException(status_code=500, detail="Failed to remove image from board") raise HTTPException(status_code=500, detail="Failed to remove image from board")
@ -79,10 +79,10 @@ async def add_images_to_board(
board_id=board_id, image_name=image_name board_id=board_id, image_name=image_name
) )
added_image_names.append(image_name) added_image_names.append(image_name)
except: except Exception:
pass pass
return AddImagesToBoardResult(board_id=board_id, added_image_names=added_image_names) return AddImagesToBoardResult(board_id=board_id, added_image_names=added_image_names)
except Exception as e: except Exception:
raise HTTPException(status_code=500, detail="Failed to add images to board") raise HTTPException(status_code=500, detail="Failed to add images to board")
@ -105,8 +105,8 @@ async def remove_images_from_board(
try: try:
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name) ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
removed_image_names.append(image_name) removed_image_names.append(image_name)
except: except Exception:
pass pass
return RemoveImagesFromBoardResult(removed_image_names=removed_image_names) return RemoveImagesFromBoardResult(removed_image_names=removed_image_names)
except Exception as e: except Exception:
raise HTTPException(status_code=500, detail="Failed to remove images from board") raise HTTPException(status_code=500, detail="Failed to remove images from board")

View File

@ -37,7 +37,7 @@ async def create_board(
try: try:
result = ApiDependencies.invoker.services.boards.create(board_name=board_name) result = ApiDependencies.invoker.services.boards.create(board_name=board_name)
return result return result
except Exception as e: except Exception:
raise HTTPException(status_code=500, detail="Failed to create board") raise HTTPException(status_code=500, detail="Failed to create board")
@ -50,7 +50,7 @@ async def get_board(
try: try:
result = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) result = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
return result return result
except Exception as e: except Exception:
raise HTTPException(status_code=404, detail="Board not found") raise HTTPException(status_code=404, detail="Board not found")
@ -73,7 +73,7 @@ async def update_board(
try: try:
result = ApiDependencies.invoker.services.boards.update(board_id=board_id, changes=changes) result = ApiDependencies.invoker.services.boards.update(board_id=board_id, changes=changes)
return result return result
except Exception as e: except Exception:
raise HTTPException(status_code=500, detail="Failed to update board") raise HTTPException(status_code=500, detail="Failed to update board")
@ -105,7 +105,7 @@ async def delete_board(
deleted_board_images=deleted_board_images, deleted_board_images=deleted_board_images,
deleted_images=[], deleted_images=[],
) )
except Exception as e: except Exception:
raise HTTPException(status_code=500, detail="Failed to delete board") raise HTTPException(status_code=500, detail="Failed to delete board")

View File

@ -55,7 +55,7 @@ async def upload_image(
if crop_visible: if crop_visible:
bbox = pil_image.getbbox() bbox = pil_image.getbbox()
pil_image = pil_image.crop(bbox) pil_image = pil_image.crop(bbox)
except: except Exception:
# Error opening the image # Error opening the image
raise HTTPException(status_code=415, detail="Failed to read image") raise HTTPException(status_code=415, detail="Failed to read image")
@ -73,7 +73,7 @@ async def upload_image(
response.headers["Location"] = image_dto.image_url response.headers["Location"] = image_dto.image_url
return image_dto return image_dto
except Exception as e: except Exception:
raise HTTPException(status_code=500, detail="Failed to create image") raise HTTPException(status_code=500, detail="Failed to create image")
@ -85,7 +85,7 @@ async def delete_image(
try: try:
ApiDependencies.invoker.services.images.delete(image_name) ApiDependencies.invoker.services.images.delete(image_name)
except Exception as e: except Exception:
# TODO: Does this need any exception handling at all? # TODO: Does this need any exception handling at all?
pass pass
@ -97,7 +97,7 @@ async def clear_intermediates() -> int:
try: try:
count_deleted = ApiDependencies.invoker.services.images.delete_intermediates() count_deleted = ApiDependencies.invoker.services.images.delete_intermediates()
return count_deleted return count_deleted
except Exception as e: except Exception:
raise HTTPException(status_code=500, detail="Failed to clear intermediates") raise HTTPException(status_code=500, detail="Failed to clear intermediates")
pass pass
@ -115,7 +115,7 @@ async def update_image(
try: try:
return ApiDependencies.invoker.services.images.update(image_name, image_changes) return ApiDependencies.invoker.services.images.update(image_name, image_changes)
except Exception as e: except Exception:
raise HTTPException(status_code=400, detail="Failed to update image") raise HTTPException(status_code=400, detail="Failed to update image")
@ -131,7 +131,7 @@ async def get_image_dto(
try: try:
return ApiDependencies.invoker.services.images.get_dto(image_name) return ApiDependencies.invoker.services.images.get_dto(image_name)
except Exception as e: except Exception:
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@ -147,7 +147,7 @@ async def get_image_metadata(
try: try:
return ApiDependencies.invoker.services.images.get_metadata(image_name) return ApiDependencies.invoker.services.images.get_metadata(image_name)
except Exception as e: except Exception:
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@ -183,7 +183,7 @@ async def get_image_full(
) )
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}" response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
return response return response
except Exception as e: except Exception:
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@ -212,7 +212,7 @@ async def get_image_thumbnail(
response = FileResponse(path, media_type="image/webp", content_disposition_type="inline") response = FileResponse(path, media_type="image/webp", content_disposition_type="inline")
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}" response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
return response return response
except Exception as e: except Exception:
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@ -234,7 +234,7 @@ async def get_image_urls(
image_url=image_url, image_url=image_url,
thumbnail_url=thumbnail_url, thumbnail_url=thumbnail_url,
) )
except Exception as e: except Exception:
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@ -282,10 +282,10 @@ async def delete_images_from_list(
try: try:
ApiDependencies.invoker.services.images.delete(image_name) ApiDependencies.invoker.services.images.delete(image_name)
deleted_images.append(image_name) deleted_images.append(image_name)
except: except Exception:
pass pass
return DeleteImagesFromListResult(deleted_images=deleted_images) return DeleteImagesFromListResult(deleted_images=deleted_images)
except Exception as e: except Exception:
raise HTTPException(status_code=500, detail="Failed to delete images") raise HTTPException(status_code=500, detail="Failed to delete images")
@ -303,10 +303,10 @@ async def star_images_in_list(
try: try:
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=True)) ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=True))
updated_image_names.append(image_name) updated_image_names.append(image_name)
except: except Exception:
pass pass
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names) return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
except Exception as e: except Exception:
raise HTTPException(status_code=500, detail="Failed to star images") raise HTTPException(status_code=500, detail="Failed to star images")
@ -320,8 +320,8 @@ async def unstar_images_in_list(
try: try:
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=False)) ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=False))
updated_image_names.append(image_name) updated_image_names.append(image_name)
except: except Exception:
pass pass
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names) return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
except Exception as e: except Exception:
raise HTTPException(status_code=500, detail="Failed to unstar images") raise HTTPException(status_code=500, detail="Failed to unstar images")

View File

@ -1,12 +1,13 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Annotated, List, Optional, Union from typing import Annotated, Optional, Union
from fastapi import Body, HTTPException, Path, Query, Response from fastapi import Body, HTTPException, Path, Query, Response
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from pydantic.fields import Field from pydantic.fields import Field
from ...invocations import * # Importing * is bad karma but needed here for node detection
from ...invocations import * # noqa: F401 F403
from ...invocations.baseinvocation import BaseInvocation from ...invocations.baseinvocation import BaseInvocation
from ...services.graph import ( from ...services.graph import (
Edge, Edge,

View File

@ -1,6 +1,5 @@
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team # Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import asyncio import asyncio
import sys
from inspect import signature from inspect import signature
import logging import logging
@ -17,21 +16,11 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware
from pathlib import Path from pathlib import Path
from pydantic.schema import schema from pydantic.schema import schema
# This should come early so that modules can log their initialization properly
from .services.config import InvokeAIAppConfig from .services.config import InvokeAIAppConfig
from ..backend.util.logging import InvokeAILogger from ..backend.util.logging import InvokeAILogger
app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
logger = InvokeAILogger.getLogger(config=app_config)
from invokeai.version.invokeai_version import __version__ from invokeai.version.invokeai_version import __version__
# we call this early so that the message appears before
# other invokeai initialization messages
if app_config.version:
print(f"InvokeAI version {__version__}")
sys.exit(0)
import invokeai.frontend.web as web_dir import invokeai.frontend.web as web_dir
import mimetypes import mimetypes
@ -40,12 +29,17 @@ from .api.routers import sessions, models, images, boards, board_images, app_inf
from .api.sockets import SocketIO from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
import torch import torch
import invokeai.backend.util.hotfixes import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
logger = InvokeAILogger.getLogger(config=app_config)
# fix for windows mimetypes registry entries being borked # fix for windows mimetypes registry entries being borked
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352 # see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
@ -230,13 +224,16 @@ def invoke_api():
# replace uvicorn's loggers with InvokeAI's for consistent appearance # replace uvicorn's loggers with InvokeAI's for consistent appearance
for logname in ["uvicorn.access", "uvicorn"]: for logname in ["uvicorn.access", "uvicorn"]:
l = logging.getLogger(logname) log = logging.getLogger(logname)
l.handlers.clear() log.handlers.clear()
for ch in logger.handlers: for ch in logger.handlers:
l.addHandler(ch) log.addHandler(ch)
loop.run_until_complete(server.serve()) loop.run_until_complete(server.serve())
if __name__ == "__main__": if __name__ == "__main__":
invoke_api() if app_config.version:
print(f"InvokeAI version {__version__}")
else:
invoke_api()

View File

@ -145,10 +145,10 @@ def set_autocompleter(services: InvocationServices) -> Completer:
completer = Completer(services.model_manager) completer = Completer(services.model_manager)
readline.set_completer(completer.complete) readline.set_completer(completer.complete)
# pyreadline3 does not have a set_auto_history() method
try: try:
readline.set_auto_history(True) readline.set_auto_history(True)
except: except AttributeError:
# pyreadline3 does not have a set_auto_history() method
pass pass
readline.set_pre_input_hook(completer._pre_input_hook) readline.set_pre_input_hook(completer._pre_input_hook)
readline.set_completer_delims(" ") readline.set_completer_delims(" ")

View File

@ -13,16 +13,8 @@ from pydantic.fields import Field
# This should come early so that the logger can pick up its configuration options # This should come early so that the logger can pick up its configuration options
from .services.config import InvokeAIAppConfig from .services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger().getLogger(config=config)
from invokeai.version.invokeai_version import __version__ from invokeai.version.invokeai_version import __version__
# we call this early so that the message appears before other invokeai initialization messages
if config.version:
print(f"InvokeAI version {__version__}")
sys.exit(0)
from invokeai.app.services.board_image_record_storage import ( from invokeai.app.services.board_image_record_storage import (
SqliteBoardImageRecordStorage, SqliteBoardImageRecordStorage,
@ -62,10 +54,15 @@ from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage from .services.sqlite import SqliteItemStorage
import torch import torch
import invokeai.backend.util.hotfixes import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger().getLogger(config=config)
class CliCommand(BaseModel): class CliCommand(BaseModel):
@ -482,4 +479,7 @@ def invoke_cli():
if __name__ == "__main__": if __name__ == "__main__":
invoke_cli() if config.version:
print(f"InvokeAI version {__version__}")
else:
invoke_cli()

View File

@ -5,10 +5,10 @@ from typing import Literal
import numpy as np import numpy as np
from pydantic import validator from pydantic import validator
from invokeai.app.invocations.primitives import ImageCollectionOutput, ImageField, IntegerCollectionOutput from invokeai.app.invocations.primitives import IntegerCollectionOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIType, tags, title from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
@title("Integer Range") @title("Integer Range")

View File

@ -12,7 +12,7 @@ from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion impor
SDXLConditioningInfo, SDXLConditioningInfo,
) )
from ...backend.model_management import ModelPatcher, ModelType from ...backend.model_management.models import ModelType
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import ModelNotFoundException from ...backend.model_management.models import ModelNotFoundException
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent

View File

@ -29,7 +29,7 @@ from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.primitives import ImageField, ImageOutput from invokeai.app.invocations.primitives import ImageField, ImageOutput
from ...backend.model_management import BaseModelType, ModelType from ...backend.model_management import BaseModelType
from ..models.image import ImageCategory, ResourceOrigin from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,

View File

@ -90,7 +90,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
return im return im
# Find all invalid tiles and replace with a random valid tile # Find all invalid tiles and replace with a random valid tile
replace_count = (tiles_mask == False).sum() replace_count = (tiles_mask is False).sum()
rng = np.random.default_rng(seed=seed) rng = np.random.default_rng(seed=seed)
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :] tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]

View File

@ -15,7 +15,7 @@ from diffusers.models.attention_processor import (
) )
from diffusers.schedulers import DPMSolverSDEScheduler from diffusers.schedulers import DPMSolverSDEScheduler
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator from pydantic import validator
from torchvision.transforms.functional import resize as tv_resize from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.metadata import CoreMetadata from invokeai.app.invocations.metadata import CoreMetadata
@ -32,7 +32,7 @@ from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models import ModelType, SilenceWarnings from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from ...backend.model_management import BaseModelType, ModelPatcher from ...backend.model_management.models import BaseModelType
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import ( from ...backend.stable_diffusion.diffusers_pipeline import (
@ -47,12 +47,10 @@ from ...backend.util.devices import choose_precision, choose_torch_device
from ..models.image import ImageCategory, ResourceOrigin from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
BaseInvocationOutput,
FieldDescriptions, FieldDescriptions,
Input, Input,
InputField, InputField,
InvocationContext, InvocationContext,
OutputField,
UIType, UIType,
tags, tags,
title, title,

View File

@ -1,5 +1,5 @@
import copy import copy
from typing import List, Literal, Optional, Union from typing import List, Literal, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field

View File

@ -16,7 +16,6 @@ from .baseinvocation import (
InputField, InputField,
InvocationContext, InvocationContext,
OutputField, OutputField,
UIType,
tags, tags,
title, title,
) )

View File

@ -2,14 +2,13 @@
import inspect import inspect
import re import re
from contextlib import ExitStack
# from contextlib import ExitStack
from typing import List, Literal, Optional, Union from typing import List, Literal, Optional, Union
import numpy as np import numpy as np
import torch import torch
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from diffusers.image_processor import VaeImageProcessor from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
from tqdm import tqdm from tqdm import tqdm
@ -72,7 +71,7 @@ class ONNXPromptInvocation(BaseInvocation):
text_encoder_info = context.services.model_manager.get_model( text_encoder_info = context.services.model_manager.get_model(
**self.clip.text_encoder.dict(), **self.clip.text_encoder.dict(),
) )
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack: with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack:
loras = [ loras = [
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) (context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
for lora in self.clip.loras for lora in self.clip.loras
@ -259,7 +258,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict()) unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
with unet_info as unet, ExitStack() as stack: with unet_info as unet: # , ExitStack() as stack:
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
loras = [ loras = [
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) (context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)

View File

@ -38,13 +38,10 @@ from easing_functions import (
SineEaseInOut, SineEaseInOut,
SineEaseOut, SineEaseOut,
) )
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator from matplotlib.ticker import MaxNLocator
from pydantic import BaseModel, Field
from invokeai.app.invocations.primitives import FloatCollectionOutput from invokeai.app.invocations.primitives import FloatCollectionOutput
from ...backend.util.logging import InvokeAILogger
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title

View File

@ -1,9 +1,8 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal, Optional, Tuple, Union from typing import Literal, Optional, Tuple
import torch import torch
from anyio import Condition
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from .baseinvocation import ( from .baseinvocation import (

View File

@ -1,6 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
from pathlib import Path from pathlib import Path
from typing import Literal, Union from typing import Literal
import cv2 as cv import cv2 as cv
import numpy as np import numpy as np

View File

@ -1,18 +1,14 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from logging import Logger from logging import Logger
from typing import List, Union, Optional from typing import Optional
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_record_storage import ( from invokeai.app.services.board_record_storage import (
BoardRecord, BoardRecord,
BoardRecordStorageBase, BoardRecordStorageBase,
) )
from invokeai.app.services.image_record_storage import ( from invokeai.app.services.image_record_storage import ImageRecordStorageBase
ImageRecordStorageBase,
OffsetPaginatedResults,
)
from invokeai.app.services.models.board_record import BoardDTO from invokeai.app.services.models.board_record import BoardDTO
from invokeai.app.services.models.image_record import ImageDTO, image_record_to_dto
from invokeai.app.services.urls import UrlServiceBase from invokeai.app.services.urls import UrlServiceBase

View File

@ -1,15 +1,14 @@
from abc import ABC, abstractmethod
from typing import Optional, cast
import sqlite3
import threading import threading
from typing import Optional, Union
import uuid import uuid
from abc import ABC, abstractmethod
from typing import Optional, Union, cast
import sqlite3
from invokeai.app.services.image_record_storage import OffsetPaginatedResults from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.board_record import ( from invokeai.app.services.models.board_record import (
BoardRecord, BoardRecord,
deserialize_board_record, deserialize_board_record,
) )
from pydantic import BaseModel, Field, Extra from pydantic import BaseModel, Field, Extra
@ -230,7 +229,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
# Change the name of a board # Change the name of a board
if changes.board_name is not None: if changes.board_name is not None:
self._cursor.execute( self._cursor.execute(
f"""--sql """--sql
UPDATE boards UPDATE boards
SET board_name = ? SET board_name = ?
WHERE board_id = ?; WHERE board_id = ?;
@ -241,7 +240,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
# Change the cover image of a board # Change the cover image of a board
if changes.cover_image_name is not None: if changes.cover_image_name is not None:
self._cursor.execute( self._cursor.execute(
f"""--sql """--sql
UPDATE boards UPDATE boards
SET cover_image_name = ? SET cover_image_name = ?
WHERE board_id = ?; WHERE board_id = ?;

View File

@ -167,7 +167,7 @@ from argparse import ArgumentParser
from omegaconf import OmegaConf, DictConfig, ListConfig from omegaconf import OmegaConf, DictConfig, ListConfig
from pathlib import Path from pathlib import Path
from pydantic import BaseSettings, Field, parse_obj_as from pydantic import BaseSettings, Field, parse_obj_as
from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args from typing import ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args
INIT_FILE = Path("invokeai.yaml") INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db") DB_FILE = Path("invokeai.db")
@ -394,7 +394,7 @@ class InvokeAIAppConfig(InvokeAISettings):
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance') free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance') max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance') max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='auto',description='Floating point precision', category='Memory/Performance') precision : Literal['auto', 'float16', 'float32', 'autocast'] = Field(default='auto', description='Floating point precision', category='Memory/Performance')
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance') sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance') xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance') tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
@ -415,8 +415,8 @@ class InvokeAIAppConfig(InvokeAISettings):
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging") log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues # note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
log_format : Literal[tuple(['plain','color','syslog','legacy'])] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging") log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging") log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other") version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
# fmt: on # fmt: on
@ -438,7 +438,7 @@ class InvokeAIAppConfig(InvokeAISettings):
if conf is None: if conf is None:
try: try:
conf = OmegaConf.load(self.root_dir / INIT_FILE) conf = OmegaConf.load(self.root_dir / INIT_FILE)
except: except Exception:
pass pass
InvokeAISettings.initconf = conf InvokeAISettings.initconf = conf
@ -457,7 +457,7 @@ class InvokeAIAppConfig(InvokeAISettings):
""" """
if ( if (
cls.singleton_config is None cls.singleton_config is None
or type(cls.singleton_config) != cls or type(cls.singleton_config) is not cls
or (kwargs and cls.singleton_init != kwargs) or (kwargs and cls.singleton_init != kwargs)
): ):
cls.singleton_config = cls(**kwargs) cls.singleton_config = cls(**kwargs)

View File

@ -9,7 +9,8 @@ import networkx as nx
from pydantic import BaseModel, root_validator, validator from pydantic import BaseModel, root_validator, validator
from pydantic.fields import Field from pydantic.fields import Field
from ..invocations import * # Importing * is bad karma but needed here for node detection
from ..invocations import * # noqa: F401 F403
from ..invocations.baseinvocation import ( from ..invocations.baseinvocation import (
BaseInvocation, BaseInvocation,
BaseInvocationOutput, BaseInvocationOutput,
@ -445,7 +446,7 @@ class Graph(BaseModel):
node = graph.nodes[node_id] node = graph.nodes[node_id]
# Ensure the node type matches the new node # Ensure the node type matches the new node
if type(node) != type(new_node): if type(node) is not type(new_node):
raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}") raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}")
# Ensure the new id is either the same or is not in the graph # Ensure the new id is either the same or is not in the graph
@ -632,7 +633,7 @@ class Graph(BaseModel):
[ [
t t
for input_field in input_fields for input_field in input_fields
for t in ([input_field] if get_origin(input_field) == None else get_args(input_field)) for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
if t != NoneType if t != NoneType
] ]
) # Get unique types ) # Get unique types
@ -923,7 +924,7 @@ class GraphExecutionState(BaseModel):
None, None,
) )
if next_node_id == None: if next_node_id is None:
return None return None
# Get all parents of the next node # Get all parents of the next node

View File

@ -179,7 +179,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
return None if image_name not in self.__cache else self.__cache[image_name] return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: Path, image: PILImageType): def __set_cache(self, image_name: Path, image: PILImageType):
if not image_name in self.__cache: if image_name not in self.__cache:
self.__cache[image_name] = image self.__cache[image_name] = image
self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size: if len(self.__cache) > self.__max_cache_size:

View File

@ -282,7 +282,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
f"""--sql """--sql
SELECT images.metadata FROM images SELECT images.metadata FROM images
WHERE image_name = ?; WHERE image_name = ?;
""", """,
@ -309,7 +309,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Change the category of the image # Change the category of the image
if changes.image_category is not None: if changes.image_category is not None:
self._cursor.execute( self._cursor.execute(
f"""--sql """--sql
UPDATE images UPDATE images
SET image_category = ? SET image_category = ?
WHERE image_name = ?; WHERE image_name = ?;
@ -320,7 +320,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Change the session associated with the image # Change the session associated with the image
if changes.session_id is not None: if changes.session_id is not None:
self._cursor.execute( self._cursor.execute(
f"""--sql """--sql
UPDATE images UPDATE images
SET session_id = ? SET session_id = ?
WHERE image_name = ?; WHERE image_name = ?;
@ -331,7 +331,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Change the image's `is_intermediate`` flag # Change the image's `is_intermediate`` flag
if changes.is_intermediate is not None: if changes.is_intermediate is not None:
self._cursor.execute( self._cursor.execute(
f"""--sql """--sql
UPDATE images UPDATE images
SET is_intermediate = ? SET is_intermediate = ?
WHERE image_name = ?; WHERE image_name = ?;
@ -342,7 +342,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Change the image's `starred`` state # Change the image's `starred`` state
if changes.starred is not None: if changes.starred is not None:
self._cursor.execute( self._cursor.execute(
f"""--sql """--sql
UPDATE images UPDATE images
SET starred = ? SET starred = ?
WHERE image_name = ?; WHERE image_name = ?;

View File

@ -1,4 +1,3 @@
import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from logging import Logger from logging import Logger
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
@ -379,10 +378,10 @@ class ImageService(ImageServiceABC):
self._services.image_files.delete(image_name) self._services.image_files.delete(image_name)
self._services.image_records.delete(image_name) self._services.image_records.delete(image_name)
except ImageRecordDeleteException: except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record") self._services.logger.error("Failed to delete image record")
raise raise
except ImageFileDeleteException: except ImageFileDeleteException:
self._services.logger.error(f"Failed to delete image file") self._services.logger.error("Failed to delete image file")
raise raise
except Exception as e: except Exception as e:
self._services.logger.error("Problem deleting image record and file") self._services.logger.error("Problem deleting image record and file")
@ -395,10 +394,10 @@ class ImageService(ImageServiceABC):
self._services.image_files.delete(image_name) self._services.image_files.delete(image_name)
self._services.image_records.delete_many(image_names) self._services.image_records.delete_many(image_names)
except ImageRecordDeleteException: except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image records") self._services.logger.error("Failed to delete image records")
raise raise
except ImageFileDeleteException: except ImageFileDeleteException:
self._services.logger.error(f"Failed to delete image files") self._services.logger.error("Failed to delete image files")
raise raise
except Exception as e: except Exception as e:
self._services.logger.error("Problem deleting image records and files") self._services.logger.error("Problem deleting image records and files")
@ -412,10 +411,10 @@ class ImageService(ImageServiceABC):
self._services.image_files.delete(image_name) self._services.image_files.delete(image_name)
return count return count
except ImageRecordDeleteException: except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image records") self._services.logger.error("Failed to delete image records")
raise raise
except ImageFileDeleteException: except ImageFileDeleteException:
self._services.logger.error(f"Failed to delete image files") self._services.logger.error("Failed to delete image files")
raise raise
except Exception as e: except Exception as e:
self._services.logger.error("Problem deleting image records and files") self._services.logger.error("Problem deleting image records and files")

View File

@ -7,6 +7,7 @@ if TYPE_CHECKING:
from invokeai.app.services.board_images import BoardImagesServiceABC from invokeai.app.services.board_images import BoardImagesServiceABC
from invokeai.app.services.boards import BoardServiceABC from invokeai.app.services.boards import BoardServiceABC
from invokeai.app.services.images import ImageServiceABC from invokeai.app.services.images import ImageServiceABC
from invokeai.app.services.invocation_stats import InvocationStatsServiceBase
from invokeai.app.services.model_manager_service import ModelManagerServiceBase from invokeai.app.services.model_manager_service import ModelManagerServiceBase
from invokeai.app.services.events import EventServiceBase from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.latent_storage import LatentsStorageBase from invokeai.app.services.latent_storage import LatentsStorageBase

View File

@ -1,7 +1,6 @@
# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com> # Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
"""Utility to collect execution time and GPU usage stats on invocations in flight""" """Utility to collect execution time and GPU usage stats on invocations in flight
"""
Usage: Usage:
statistics = InvocationStatsService(graph_execution_manager) statistics = InvocationStatsService(graph_execution_manager)

View File

@ -60,7 +60,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
return None if name not in self.__cache else self.__cache[name] return None if name not in self.__cache else self.__cache[name]
def __set_cache(self, name: str, data: torch.Tensor): def __set_cache(self, name: str, data: torch.Tensor):
if not name in self.__cache: if name not in self.__cache:
self.__cache[name] = data self.__cache[name] = data
self.__cache_ids.put(name) self.__cache_ids.put(name)
if self.__cache_ids.qsize() > self.__max_cache_size: if self.__cache_ids.qsize() > self.__max_cache_size:

View File

@ -1,3 +1,4 @@
from typing import Union
import torch import torch
import numpy as np import numpy as np
import cv2 import cv2
@ -5,7 +6,7 @@ from PIL import Image
from diffusers.utils import PIL_INTERPOLATION from diffusers.utils import PIL_INTERPOLATION
from einops import rearrange from einops import rearrange
from controlnet_aux.util import HWC3, resize_image from controlnet_aux.util import HWC3
################################################################### ###################################################################
# Copy of scripts/lvminthin.py from Mikubill/sd-webui-controlnet # Copy of scripts/lvminthin.py from Mikubill/sd-webui-controlnet
@ -232,7 +233,8 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:
k0 = float(h) / old_h k0 = float(h) / old_h
k1 = float(w) / old_w k1 = float(w) / old_w
safeint = lambda x: int(np.round(x)) def safeint(x: Union[int, float]) -> int:
return int(np.round(x))
# if resize_mode == external_code.ResizeMode.OUTER_FIT: # if resize_mode == external_code.ResizeMode.OUTER_FIT:
if resize_mode == "fill_resize": # OUTER_FIT if resize_mode == "fill_resize": # OUTER_FIT

View File

@ -5,7 +5,6 @@ from invokeai.app.models.image import ProgressImage
from ..invocations.baseinvocation import InvocationContext from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL from ...backend.util.util import image_to_dataURL
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from invokeai.app.services.config import InvokeAIAppConfig
from ...backend.model_management.models import BaseModelType from ...backend.model_management.models import BaseModelType

View File

@ -1,5 +1,5 @@
""" """
Initialization file for invokeai.backend Initialization file for invokeai.backend
""" """
from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubModelType, ModelInfo from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubModelType, ModelInfo # noqa: F401
from .model_management.models import SilenceWarnings from .model_management.models import SilenceWarnings # noqa: F401

View File

@ -1,14 +1,16 @@
""" """
Initialization file for invokeai.backend.image_util methods. Initialization file for invokeai.backend.image_util methods.
""" """
from .patchmatch import PatchMatch from .patchmatch import PatchMatch # noqa: F401
from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata # noqa: F401
from .seamless import configure_model_padding from .seamless import configure_model_padding # noqa: F401
from .txt2mask import Txt2Mask from .txt2mask import Txt2Mask # noqa: F401
from .util import InitImageResizer, make_grid from .util import InitImageResizer, make_grid # noqa: F401
def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False): def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False):
from PIL import ImageDraw
if not debug_status: if not debug_status:
return return

View File

@ -26,7 +26,7 @@ class PngWriter:
dirlist = sorted(os.listdir(self.outdir), reverse=True) dirlist = sorted(os.listdir(self.outdir), reverse=True)
# find the first filename that matches our pattern or return 000000.0.png # find the first filename that matches our pattern or return 000000.0.png
existing_name = next( existing_name = next(
(f for f in dirlist if re.match("^(\d+)\..*\.png", f)), (f for f in dirlist if re.match(r"^(\d+)\..*\.png", f)),
"0000000.0.png", "0000000.0.png",
) )
basecount = int(existing_name.split(".", 1)[0]) + 1 basecount = int(existing_name.split(".", 1)[0]) + 1
@ -98,11 +98,11 @@ class PromptFormatter:
# to do: put model name into the t2i object # to do: put model name into the t2i object
# switches.append(f'--model{t2i.model_name}') # switches.append(f'--model{t2i.model_name}')
if opt.seamless or t2i.seamless: if opt.seamless or t2i.seamless:
switches.append(f"--seamless") switches.append("--seamless")
if opt.init_img: if opt.init_img:
switches.append(f"-I{opt.init_img}") switches.append(f"-I{opt.init_img}")
if opt.fit: if opt.fit:
switches.append(f"--fit") switches.append("--fit")
if opt.strength and opt.init_img is not None: if opt.strength and opt.init_img is not None:
switches.append(f"-f{opt.strength or t2i.strength}") switches.append(f"-f{opt.strength or t2i.strength}")
if opt.gfpgan_strength: if opt.gfpgan_strength:

View File

@ -52,7 +52,6 @@ from invokeai.frontend.install.widgets import (
SingleSelectColumns, SingleSelectColumns,
CenteredButtonPress, CenteredButtonPress,
FileBox, FileBox,
IntTitleSlider,
set_min_terminal_size, set_min_terminal_size,
CyclingForm, CyclingForm,
MIN_COLS, MIN_COLS,

View File

@ -116,7 +116,7 @@ class MigrateTo3(object):
appropriate location within the destination models directory. appropriate location within the destination models directory.
""" """
directories_scanned = set() directories_scanned = set()
for root, dirs, files in os.walk(src_dir): for root, dirs, files in os.walk(src_dir, followlinks=True):
for d in dirs: for d in dirs:
try: try:
model = Path(root, d) model = Path(root, d)
@ -525,7 +525,7 @@ def do_migrate(src_directory: Path, dest_directory: Path):
if version_3: # write into the dest directory if version_3: # write into the dest directory
try: try:
shutil.copy(dest_directory / "configs" / "models.yaml", config_file) shutil.copy(dest_directory / "configs" / "models.yaml", config_file)
except: except Exception:
MigrateTo3.initialize_yaml(config_file) MigrateTo3.initialize_yaml(config_file)
mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory
(dest_directory / "models").replace(dest_models) (dest_directory / "models").replace(dest_models)

View File

@ -12,7 +12,6 @@ from typing import Optional, List, Dict, Callable, Union, Set
import requests import requests
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from diffusers import logging as dlogging from diffusers import logging as dlogging
import onnx
import torch import torch
from huggingface_hub import hf_hub_url, HfFolder, HfApi from huggingface_hub import hf_hub_url, HfFolder, HfApi
from omegaconf import OmegaConf from omegaconf import OmegaConf

View File

@ -1,10 +1,10 @@
""" """
Initialization file for invokeai.backend.model_management Initialization file for invokeai.backend.model_management
""" """
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType # noqa: F401
from .model_cache import ModelCache from .model_cache import ModelCache # noqa: F401
from .lora import ModelPatcher, ONNXModelPatcher from .lora import ModelPatcher, ONNXModelPatcher # noqa: F401
from .models import ( from .models import ( # noqa: F401
BaseModelType, BaseModelType,
ModelType, ModelType,
SubModelType, SubModelType,
@ -12,5 +12,4 @@ from .models import (
ModelNotFoundException, ModelNotFoundException,
DuplicateModelException, DuplicateModelException,
) )
from .model_merge import ModelMerger, MergeInterpolationMethod from .model_merge import ModelMerger, MergeInterpolationMethod # noqa: F401
from .lora import ModelPatcher

View File

@ -5,21 +5,16 @@ from contextlib import contextmanager
from typing import Optional, Dict, Tuple, Any, Union, List from typing import Optional, Dict, Tuple, Any, Union, List
from pathlib import Path from pathlib import Path
import torch
from safetensors.torch import load_file
from torch.utils.hooks import RemovableHandle
from diffusers.models import UNet2DConditionModel
from transformers import CLIPTextModel
from onnx import numpy_helper
from onnxruntime import OrtValue
import numpy as np import numpy as np
import torch
from compel.embeddings_provider import BaseTextualInversionManager from compel.embeddings_provider import BaseTextualInversionManager
from diffusers.models import UNet2DConditionModel from diffusers.models import UNet2DConditionModel
from safetensors.torch import load_file from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer from transformers import CLIPTextModel, CLIPTokenizer
from .models.lora import LoRAModel
""" """
loras = [ loras = [
(lora_model1, 0.7), (lora_model1, 0.7),
@ -52,7 +47,7 @@ class ModelPatcher:
module = module.get_submodule(submodule_name) module = module.get_submodule(submodule_name)
module_key += "." + submodule_name module_key += "." + submodule_name
submodule_name = key_parts.pop(0) submodule_name = key_parts.pop(0)
except: except Exception:
submodule_name += "_" + key_parts.pop(0) submodule_name += "_" + key_parts.pop(0)
module = module.get_submodule(submodule_name) module = module.get_submodule(submodule_name)
@ -312,7 +307,8 @@ class TextualInversionManager(BaseTextualInversionManager):
class ONNXModelPatcher: class ONNXModelPatcher:
from .models.base import IAIOnnxRuntimeModel, OnnxRuntimeModel from .models.base import IAIOnnxRuntimeModel
from diffusers import OnnxRuntimeModel
@classmethod @classmethod
@contextmanager @contextmanager
@ -341,7 +337,7 @@ class ONNXModelPatcher:
def apply_lora( def apply_lora(
cls, cls,
model: IAIOnnxRuntimeModel, model: IAIOnnxRuntimeModel,
loras: List[Tuple[LoraModel, float]], loras: List[Tuple[LoRAModel, float]],
prefix: str, prefix: str,
): ):
from .models.base import IAIOnnxRuntimeModel from .models.base import IAIOnnxRuntimeModel

View File

@ -273,7 +273,7 @@ class ModelCache(object):
self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}") self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
self.cache._print_cuda_stats() self.cache._print_cuda_stats()
except: except Exception:
self.cache_entry.unlock() self.cache_entry.unlock()
raise raise

View File

@ -419,12 +419,12 @@ class ModelManager(object):
base_model_str, model_type_str, model_name = model_key.split("/", 2) base_model_str, model_type_str, model_name = model_key.split("/", 2)
try: try:
model_type = ModelType(model_type_str) model_type = ModelType(model_type_str)
except: except Exception:
raise Exception(f"Unknown model type: {model_type_str}") raise Exception(f"Unknown model type: {model_type_str}")
try: try:
base_model = BaseModelType(base_model_str) base_model = BaseModelType(base_model_str)
except: except Exception:
raise Exception(f"Unknown base model: {base_model_str}") raise Exception(f"Unknown base model: {base_model_str}")
return (model_name, base_model, model_type) return (model_name, base_model, model_type)
@ -855,7 +855,7 @@ class ModelManager(object):
info.pop("config") info.pop("config")
result = self.add_model(model_name, base_model, model_type, model_attributes=info, clobber=True) result = self.add_model(model_name, base_model, model_type, model_attributes=info, clobber=True)
except: except Exception:
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error! # something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
rmtree(new_diffusers_path) rmtree(new_diffusers_path)
raise raise
@ -1042,7 +1042,7 @@ class ModelManager(object):
# Patch in the SD VAE from core so that it is available for use by the UI # Patch in the SD VAE from core so that it is available for use by the UI
try: try:
self.heuristic_import({str(self.resolve_model_path("core/convert/sd-vae-ft-mse"))}) self.heuristic_import({str(self.resolve_model_path("core/convert/sd-vae-ft-mse"))})
except: except Exception:
pass pass
installer = ModelInstall( installer = ModelInstall(

View File

@ -217,9 +217,9 @@ class ModelProbe(object):
raise "The model {model_name} is potentially infected by malware. Aborting import." raise "The model {model_name} is potentially infected by malware. Aborting import."
###################################################3 # ##################################################3
# Checkpoint probing # Checkpoint probing
###################################################3 # ##################################################3
class ProbeBase(object): class ProbeBase(object):
def get_base_type(self) -> BaseModelType: def get_base_type(self) -> BaseModelType:
pass pass
@ -431,7 +431,7 @@ class PipelineFolderProbe(FolderProbeBase):
return ModelVariantType.Depth return ModelVariantType.Depth
elif in_channels == 4: elif in_channels == 4:
return ModelVariantType.Normal return ModelVariantType.Normal
except: except Exception:
pass pass
return ModelVariantType.Normal return ModelVariantType.Normal

View File

@ -56,7 +56,7 @@ class ModelSearch(ABC):
self.on_search_completed() self.on_search_completed()
def walk_directory(self, path: Path): def walk_directory(self, path: Path):
for root, dirs, files in os.walk(path): for root, dirs, files in os.walk(path, followlinks=True):
if str(Path(root).name).startswith("."): if str(Path(root).name).startswith("."):
self._pruned_paths.add(root) self._pruned_paths.add(root)
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]): if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):

View File

@ -2,7 +2,7 @@ import inspect
from enum import Enum from enum import Enum
from pydantic import BaseModel from pydantic import BaseModel
from typing import Literal, get_origin from typing import Literal, get_origin
from .base import ( from .base import ( # noqa: F401
BaseModelType, BaseModelType,
ModelType, ModelType,
SubModelType, SubModelType,
@ -118,7 +118,7 @@ def get_model_config_enums():
fields = model_config.__annotations__ fields = model_config.__annotations__
try: try:
field = fields["model_format"] field = fields["model_format"]
except: except Exception:
raise Exception("format field not found") raise Exception("format field not found")
# model_format: None # model_format: None

View File

@ -3,27 +3,28 @@ import os
import sys import sys
import typing import typing
import inspect import inspect
from enum import Enum import warnings
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from contextlib import suppress
from enum import Enum
from pathlib import Path from pathlib import Path
from picklescan.scanner import scan_file_path from picklescan.scanner import scan_file_path
import torch import torch
import numpy as np import numpy as np
import safetensors.torch
from pathlib import Path
from diffusers import DiffusionPipeline, ConfigMixin, OnnxRuntimeModel
from contextlib import suppress
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
import onnx import onnx
import safetensors.torch
from diffusers import DiffusionPipeline, ConfigMixin
from onnx import numpy_helper from onnx import numpy_helper
from onnxruntime import ( from onnxruntime import (
InferenceSession, InferenceSession,
SessionOptions, SessionOptions,
get_available_providers, get_available_providers,
) )
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging
class DuplicateModelException(Exception): class DuplicateModelException(Exception):
@ -171,7 +172,7 @@ class ModelBase(metaclass=ABCMeta):
fields = value.__annotations__ fields = value.__annotations__
try: try:
field = fields["model_format"] field = fields["model_format"]
except: except Exception:
raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})") raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})")
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum): if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
@ -244,7 +245,7 @@ class DiffusersModel(ModelBase):
try: try:
config_data = DiffusionPipeline.load_config(self.model_path) config_data = DiffusionPipeline.load_config(self.model_path)
# config_data = json.loads(os.path.join(self.model_path, "model_index.json")) # config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
except: except Exception:
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)") raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
config_data.pop("_ignore_files", None) config_data.pop("_ignore_files", None)
@ -343,7 +344,7 @@ def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, vari
with open(os.path.join(model_path, file), "r") as f: with open(os.path.join(model_path, file), "r") as f:
index_data = json.loads(f.read()) index_data = json.loads(f.read())
return int(index_data["metadata"]["total_size"]) return int(index_data["metadata"]["total_size"])
except: except Exception:
pass pass
# calculate files size if there is no index file # calculate files size if there is no index file
@ -440,7 +441,7 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
if str(path).endswith(".safetensors"): if str(path).endswith(".safetensors"):
try: try:
checkpoint = _fast_safetensors_reader(path) checkpoint = _fast_safetensors_reader(path)
except: except Exception:
# TODO: create issue for support "meta"? # TODO: create issue for support "meta"?
checkpoint = safetensors.torch.load_file(path, device="cpu") checkpoint = safetensors.torch.load_file(path, device="cpu")
else: else:
@ -452,11 +453,6 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
return checkpoint return checkpoint
import warnings
from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging
class SilenceWarnings(object): class SilenceWarnings(object):
def __init__(self): def __init__(self):
self.transformers_verbosity = transformers_logging.get_verbosity() self.transformers_verbosity = transformers_logging.get_verbosity()
@ -639,7 +635,7 @@ class IAIOnnxRuntimeModel:
raise Exception("You should call create_session before running model") raise Exception("You should call create_session before running model")
inputs = {k: np.array(v) for k, v in kwargs.items()} inputs = {k: np.array(v) for k, v in kwargs.items()}
output_names = self.session.get_outputs() # output_names = self.session.get_outputs()
# for k in inputs: # for k in inputs:
# self.io_binding.bind_cpu_input(k, inputs[k]) # self.io_binding.bind_cpu_input(k, inputs[k])
# for name in output_names: # for name in output_names:

View File

@ -43,7 +43,7 @@ class ControlNetModel(ModelBase):
try: try:
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json") config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
# config = json.loads(os.path.join(self.model_path, "config.json")) # config = json.loads(os.path.join(self.model_path, "config.json"))
except: except Exception:
raise Exception("Invalid controlnet model! (config.json not found or invalid)") raise Exception("Invalid controlnet model! (config.json not found or invalid)")
model_class_name = config.get("_class_name", None) model_class_name = config.get("_class_name", None)
@ -53,7 +53,7 @@ class ControlNetModel(ModelBase):
try: try:
self.model_class = self._hf_definition_to_type(["diffusers", model_class_name]) self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
self.model_size = calc_model_size_by_fs(self.model_path) self.model_size = calc_model_size_by_fs(self.model_path)
except: except Exception:
raise Exception("Invalid ControlNet model!") raise Exception("Invalid ControlNet model!")
def get_size(self, child_type: Optional[SubModelType] = None): def get_size(self, child_type: Optional[SubModelType] = None):
@ -78,7 +78,7 @@ class ControlNetModel(ModelBase):
variant=variant, variant=variant,
) )
break break
except: except Exception:
pass pass
if not model: if not model:
raise ModelNotFoundException() raise ModelNotFoundException()

View File

@ -330,5 +330,5 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
config_path = config_path.relative_to(app_config.root_path) config_path = config_path.relative_to(app_config.root_path)
return str(config_path) return str(config_path)
except: except Exception:
return None return None

View File

@ -1,25 +1,17 @@
import os
import json
from enum import Enum from enum import Enum
from pydantic import Field from typing import Literal
from pathlib import Path
from typing import Literal, Optional, Union from diffusers import OnnxRuntimeModel
from .base import ( from .base import (
ModelBase,
ModelConfigBase, ModelConfigBase,
BaseModelType, BaseModelType,
ModelType, ModelType,
SubModelType,
ModelVariantType, ModelVariantType,
DiffusersModel, DiffusersModel,
SchedulerPredictionType, SchedulerPredictionType,
SilenceWarnings,
read_checkpoint_meta,
classproperty, classproperty,
OnnxRuntimeModel,
IAIOnnxRuntimeModel, IAIOnnxRuntimeModel,
) )
from invokeai.app.services.config import InvokeAIAppConfig
class StableDiffusionOnnxModelFormat(str, Enum): class StableDiffusionOnnxModelFormat(str, Enum):

View File

@ -44,14 +44,14 @@ class VaeModel(ModelBase):
try: try:
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json") config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
# config = json.loads(os.path.join(self.model_path, "config.json")) # config = json.loads(os.path.join(self.model_path, "config.json"))
except: except Exception:
raise Exception("Invalid vae model! (config.json not found or invalid)") raise Exception("Invalid vae model! (config.json not found or invalid)")
try: try:
vae_class_name = config.get("_class_name", "AutoencoderKL") vae_class_name = config.get("_class_name", "AutoencoderKL")
self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name]) self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name])
self.model_size = calc_model_size_by_fs(self.model_path) self.model_size = calc_model_size_by_fs(self.model_path)
except: except Exception:
raise Exception("Invalid vae model! (Unkown vae type)") raise Exception("Invalid vae model! (Unkown vae type)")
def get_size(self, child_type: Optional[SubModelType] = None): def get_size(self, child_type: Optional[SubModelType] = None):

View File

@ -1,11 +1,15 @@
""" """
Initialization file for the invokeai.backend.stable_diffusion package Initialization file for the invokeai.backend.stable_diffusion package
""" """
from .diffusers_pipeline import ( from .diffusers_pipeline import ( # noqa: F401
ConditioningData, ConditioningData,
PipelineIntermediateState, PipelineIntermediateState,
StableDiffusionGeneratorPipeline, StableDiffusionGeneratorPipeline,
) )
from .diffusion import InvokeAIDiffuserComponent from .diffusion import InvokeAIDiffuserComponent # noqa: F401
from .diffusion.cross_attention_map_saving import AttentionMapSaver from .diffusion.cross_attention_map_saving import AttentionMapSaver # noqa: F401
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings, BasicConditioningInfo, SDXLConditioningInfo from .diffusion.shared_invokeai_diffusion import ( # noqa: F401
PostprocessingSettings,
BasicConditioningInfo,
SDXLConditioningInfo,
)

View File

@ -2,10 +2,8 @@ from __future__ import annotations
import dataclasses import dataclasses
import inspect import inspect
import math
import secrets
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, Generic, List, Optional, Type, Union from typing import Any, Callable, List, Optional, Union
import PIL.Image import PIL.Image
import einops import einops

View File

@ -1,9 +1,9 @@
""" """
Initialization file for invokeai.models.diffusion Initialization file for invokeai.models.diffusion
""" """
from .cross_attention_control import InvokeAICrossAttentionMixin from .cross_attention_control import InvokeAICrossAttentionMixin # noqa: F401
from .cross_attention_map_saving import AttentionMapSaver from .cross_attention_map_saving import AttentionMapSaver # noqa: F401
from .shared_invokeai_diffusion import ( from .shared_invokeai_diffusion import ( # noqa: F401
InvokeAIDiffuserComponent, InvokeAIDiffuserComponent,
PostprocessingSettings, PostprocessingSettings,
BasicConditioningInfo, BasicConditioningInfo,

View File

@ -4,6 +4,7 @@
import enum import enum
import math import math
from dataclasses import dataclass, field
from typing import Callable, Optional from typing import Callable, Optional
import diffusers import diffusers
@ -12,6 +13,11 @@ import torch
from compel.cross_attention_control import Arguments from compel.cross_attention_control import Arguments
from diffusers.models.unet_2d_condition import UNet2DConditionModel from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.models.attention_processor import AttentionProcessor from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.attention_processor import (
Attention,
AttnProcessor,
SlicedAttnProcessor,
)
from torch import nn from torch import nn
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
@ -522,14 +528,6 @@ class AttnProcessor:
return hidden_states return hidden_states
""" """
from dataclasses import dataclass, field
import torch
from diffusers.models.attention_processor import (
Attention,
AttnProcessor,
SlicedAttnProcessor,
)
@dataclass @dataclass

View File

@ -5,8 +5,6 @@ import torch
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from torchvision.transforms.functional import resize as tv_resize from torchvision.transforms.functional import resize as tv_resize
from .cross_attention_control import CrossAttentionType, get_cross_attention_modules
class AttentionMapSaver: class AttentionMapSaver:
def __init__(self, token_ids: range, latents_shape: torch.Size): def __init__(self, token_ids: range, latents_shape: torch.Size):

View File

@ -3,15 +3,12 @@ from __future__ import annotations
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
import math import math
from typing import Any, Callable, Dict, Optional, Union, List from typing import Any, Callable, Optional, Union
import numpy as np
import torch import torch
from diffusers import UNet2DConditionModel from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttentionProcessor
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from .cross_attention_control import ( from .cross_attention_control import (
@ -579,7 +576,7 @@ class InvokeAIDiffuserComponent:
latents.to(device="cpu") latents.to(device="cpu")
if ( if (
h_symmetry_time_pct != None h_symmetry_time_pct is not None
and self.last_percent_through < h_symmetry_time_pct and self.last_percent_through < h_symmetry_time_pct
and percent_through >= h_symmetry_time_pct and percent_through >= h_symmetry_time_pct
): ):
@ -595,7 +592,7 @@ class InvokeAIDiffuserComponent:
) )
if ( if (
v_symmetry_time_pct != None v_symmetry_time_pct is not None
and self.last_percent_through < v_symmetry_time_pct and self.last_percent_through < v_symmetry_time_pct
and percent_through >= v_symmetry_time_pct and percent_through >= v_symmetry_time_pct
): ):

View File

@ -1,6 +1,6 @@
from ldm.modules.image_degradation.bsrgan import ( from ldm.modules.image_degradation.bsrgan import ( # noqa: F401
degradation_bsrgan_variant as degradation_fn_bsr, degradation_bsrgan_variant as degradation_fn_bsr,
) )
from ldm.modules.image_degradation.bsrgan_light import ( from ldm.modules.image_degradation.bsrgan_light import ( # noqa: F401
degradation_bsrgan_variant as degradation_fn_bsr_light, degradation_bsrgan_variant as degradation_fn_bsr_light,
) )

View File

@ -573,14 +573,15 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
""" """
image = util.uint2single(image) image = util.uint2single(image)
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 jpeg_prob, scale2_prob = 0.9, 0.25
sf_ori = sf # isp_prob = 0.25 # uncomment with `if i== 6` block below
# sf_ori = sf # uncomment with `if i== 6` block below
h1, w1 = image.shape[:2] h1, w1 = image.shape[:2]
image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = image.shape[:2] h, w = image.shape[:2]
hq = image.copy() # hq = image.copy() # uncomment with `if i== 6` block below
if sf == 4 and random.random() < scale2_prob: # downsample1 if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5: if np.random.rand() < 0.5:
@ -777,7 +778,7 @@ if __name__ == "__main__":
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
print(img_lq.shape) print(img_lq.shape)
print("bicubic", img_lq_bicubic.shape) print("bicubic", img_lq_bicubic.shape)
print(img_hq.shape) # print(img_hq.shape)
lq_nearest = cv2.resize( lq_nearest = cv2.resize(
util.single2uint(img_lq), util.single2uint(img_lq),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
@ -788,5 +789,6 @@ if __name__ == "__main__":
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0, interpolation=0,
) )
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) # img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest], axis=1)
util.imsave(img_concat, str(i) + ".png") util.imsave(img_concat, str(i) + ".png")

View File

@ -577,14 +577,15 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
""" """
image = util.uint2single(image) image = util.uint2single(image)
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 jpeg_prob, scale2_prob = 0.9, 0.25
sf_ori = sf # isp_prob = 0.25 # uncomment with `if i== 6` block below
# sf_ori = sf # uncomment with `if i== 6` block below
h1, w1 = image.shape[:2] h1, w1 = image.shape[:2]
image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = image.shape[:2] h, w = image.shape[:2]
hq = image.copy() # hq = image.copy() # uncomment with `if i== 6` block below
if sf == 4 and random.random() < scale2_prob: # downsample1 if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5: if np.random.rand() < 0.5:

View File

@ -8,8 +8,6 @@ import numpy as np
import torch import torch
from torchvision.utils import make_grid from torchvision.utils import make_grid
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
@ -50,6 +48,8 @@ def get_timestamp():
def imshow(x, title=None, cbar=False, figsize=None): def imshow(x, title=None, cbar=False, figsize=None):
import matplotlib.pyplot as plt
plt.figure(figsize=figsize) plt.figure(figsize=figsize)
plt.imshow(np.squeeze(x), interpolation="nearest", cmap="gray") plt.imshow(np.squeeze(x), interpolation="nearest", cmap="gray")
if title: if title:
@ -60,6 +60,8 @@ def imshow(x, title=None, cbar=False, figsize=None):
def surf(Z, cmap="rainbow", figsize=None): def surf(Z, cmap="rainbow", figsize=None):
import matplotlib.pyplot as plt
plt.figure(figsize=figsize) plt.figure(figsize=figsize)
ax3 = plt.axes(projection="3d") ax3 = plt.axes(projection="3d")
@ -89,7 +91,7 @@ def get_image_paths(dataroot):
def _get_paths_from_images(path): def _get_paths_from_images(path):
assert os.path.isdir(path), "{:s} is not a valid directory".format(path) assert os.path.isdir(path), "{:s} is not a valid directory".format(path)
images = [] images = []
for dirpath, _, fnames in sorted(os.walk(path)): for dirpath, _, fnames in sorted(os.walk(path, followlinks=True)):
for fname in sorted(fnames): for fname in sorted(fnames):
if is_image_file(fname): if is_image_file(fname):
img_path = os.path.join(dirpath, fname) img_path = os.path.join(dirpath, fname)

View File

@ -1 +1 @@
from .schedulers import SCHEDULER_MAP from .schedulers import SCHEDULER_MAP # noqa: F401

View File

@ -1,4 +1,4 @@
""" """
Initialization file for invokeai.backend.training Initialization file for invokeai.backend.training
""" """
from .textual_inversion_training import do_textual_inversion_training, parse_args from .textual_inversion_training import do_textual_inversion_training, parse_args # noqa: F401

View File

@ -1,7 +1,7 @@
""" """
Initialization file for invokeai.backend.util Initialization file for invokeai.backend.util
""" """
from .devices import ( from .devices import ( # noqa: F401
CPU_DEVICE, CPU_DEVICE,
CUDA_DEVICE, CUDA_DEVICE,
MPS_DEVICE, MPS_DEVICE,
@ -10,5 +10,5 @@ from .devices import (
normalize_device, normalize_device,
torch_dtype, torch_dtype,
) )
from .log import write_log from .log import write_log # noqa: F401
from .util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name, Chdir from .util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name, Chdir # noqa: F401

View File

@ -25,10 +25,15 @@ from diffusers.models.unet_2d_condition import UNet2DConditionModel
import diffusers import diffusers
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
from invokeai.backend.util.logging import InvokeAILogger
# TODO: create PR to diffusers # TODO: create PR to diffusers
# Modified ControlNetModel with encoder_attention_mask argument added # Modified ControlNetModel with encoder_attention_mask argument added
logger = InvokeAILogger.getLogger(__name__)
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
""" """
A ControlNet model. A ControlNet model.
@ -111,7 +116,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
"DownBlock2D", "DownBlock2D",
), ),
only_cross_attention: Union[bool, Tuple[bool]] = False, only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280), block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2, layers_per_block: int = 2,
downsample_padding: int = 1, downsample_padding: int = 1,
mid_block_scale_factor: float = 1, mid_block_scale_factor: float = 1,

View File

@ -27,8 +27,8 @@ def write_log_message(results, output_cntr):
log_lines = [f"{path}: {prompt}\n" for path, prompt in results] log_lines = [f"{path}: {prompt}\n" for path, prompt in results]
if len(log_lines) > 1: if len(log_lines) > 1:
subcntr = 1 subcntr = 1
for l in log_lines: for ll in log_lines:
print(f"[{output_cntr}.{subcntr}] {l}", end="") print(f"[{output_cntr}.{subcntr}] {ll}", end="")
subcntr += 1 subcntr += 1
else: else:
print(f"[{output_cntr}] {log_lines[0]}", end="") print(f"[{output_cntr}] {log_lines[0]}", end="")

View File

@ -182,13 +182,13 @@ import urllib.parse
from abc import abstractmethod from abc import abstractmethod
from pathlib import Path from pathlib import Path
from invokeai.app.services.config import InvokeAIAppConfig, get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
try: try:
import syslog import syslog
SYSLOG_AVAILABLE = True SYSLOG_AVAILABLE = True
except: except ImportError:
SYSLOG_AVAILABLE = False SYSLOG_AVAILABLE = False
@ -417,7 +417,7 @@ class InvokeAILogger(object):
syslog_args["socktype"] = _SOCK_MAP[arg_value[0]] syslog_args["socktype"] = _SOCK_MAP[arg_value[0]]
else: else:
syslog_args["address"] = arg_name syslog_args["address"] = arg_name
except: except Exception:
raise ValueError(f"{args} is not a value argument list for syslog logging") raise ValueError(f"{args} is not a value argument list for syslog logging")
return logging.handlers.SysLogHandler(**syslog_args) return logging.handlers.SysLogHandler(**syslog_args)

View File

@ -191,7 +191,7 @@ class ChunkedSlicedAttnProcessor:
assert value.shape[0] == 1 assert value.shape[0] == 1
assert hidden_states.shape[0] == 1 assert hidden_states.shape[0] == 1
dtype = query.dtype # dtype = query.dtype
if attn.upcast_attention: if attn.upcast_attention:
query = query.float() query = query.float()
key = key.float() key = key.float()

View File

@ -84,7 +84,7 @@ def count_params(model, verbose=False):
def instantiate_from_config(config, **kwargs): def instantiate_from_config(config, **kwargs):
if not "target" in config: if "target" not in config:
if config == "__is_first_stage__": if config == "__is_first_stage__":
return None return None
elif config == "__is_unconditional__": elif config == "__is_unconditional__":
@ -234,16 +234,17 @@ def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10
.repeat_interleave(d[1], 1) .repeat_interleave(d[1], 1)
) )
dot = lambda grad, shift: ( def dot(grad, shift):
torch.stack( return (
( torch.stack(
grid[: shape[0], : shape[1], 0] + shift[0], (
grid[: shape[0], : shape[1], 1] + shift[1], grid[: shape[0], : shape[1], 0] + shift[0],
), grid[: shape[0], : shape[1], 1] + shift[1],
dim=-1, ),
) dim=-1,
* grad[: shape[0], : shape[1]] )
).sum(dim=-1) * grad[: shape[0], : shape[1]]
).sum(dim=-1)
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]).to(device) n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]).to(device)
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]).to(device) n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]).to(device)
@ -287,7 +288,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
if dest.is_dir(): if dest.is_dir():
try: try:
file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1) file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)
except: except AttributeError:
file_name = os.path.basename(url) file_name = os.path.basename(url)
dest = dest / file_name dest = dest / file_name
else: else:
@ -342,7 +343,7 @@ def url_attachment_name(url: str) -> dict:
resp = requests.get(url, stream=True) resp = requests.get(url, stream=True)
match = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")) match = re.search('filename="(.+)"', resp.headers.get("Content-Disposition"))
return match.group(1) return match.group(1)
except: except Exception:
return None return None

View File

@ -1,4 +1,4 @@
""" """
Initialization file for invokeai.frontend.CLI Initialization file for invokeai.frontend.CLI
""" """
from .CLI import main as invokeai_command_line_interface from .CLI import main as invokeai_command_line_interface # noqa: F401

View File

@ -1,4 +1,4 @@
""" """
Wrapper for invokeai.backend.configure.invokeai_configure Wrapper for invokeai.backend.configure.invokeai_configure
""" """
from ...backend.install.invokeai_configure import main as invokeai_configure from ...backend.install.invokeai_configure import main as invokeai_configure # noqa: F401

View File

@ -80,7 +80,7 @@ def welcome(versions: dict):
def get_extras(): def get_extras():
extras = "" extras = ""
try: try:
dist = pkg_resources.get_distribution("xformers") _ = pkg_resources.get_distribution("xformers")
extras = "[xformers]" extras = "[xformers]"
except pkg_resources.DistributionNotFound: except pkg_resources.DistributionNotFound:
pass pass
@ -90,7 +90,7 @@ def get_extras():
def main(): def main():
versions = get_versions() versions = get_versions()
if invokeai_is_running(): if invokeai_is_running():
print(f":exclamation: [bold red]Please terminate all running instances of InvokeAI before updating.[/red bold]") print(":exclamation: [bold red]Please terminate all running instances of InvokeAI before updating.[/red bold]")
input("Press any key to continue...") input("Press any key to continue...")
return return
@ -122,9 +122,9 @@ def main():
print("") print("")
print("") print("")
if os.system(cmd) == 0: if os.system(cmd) == 0:
print(f":heavy_check_mark: Upgrade successful") print(":heavy_check_mark: Upgrade successful")
else: else:
print(f":exclamation: [bold red]Upgrade failed[/red bold]") print(":exclamation: [bold red]Upgrade failed[/red bold]")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -251,7 +251,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
) -> dict[str, npyscreen.widget]: ) -> dict[str, npyscreen.widget]:
"""Generic code to create model selection widgets""" """Generic code to create model selection widgets"""
widgets = dict() widgets = dict()
model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and not x in exclude] model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude]
model_labels = [self.model_labels[x] for x in model_list] model_labels = [self.model_labels[x] for x in model_list]
show_recommended = len(self.installed_models) == 0 show_recommended = len(self.installed_models) == 0
@ -357,14 +357,14 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
try: try:
v.hidden = True v.hidden = True
v.editable = False v.editable = False
except: except Exception:
pass pass
for k, v in widgets[selected_tab].items(): for k, v in widgets[selected_tab].items():
try: try:
v.hidden = False v.hidden = False
if not isinstance(v, (npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)): if not isinstance(v, (npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)):
v.editable = True v.editable = True
except: except Exception:
pass pass
self.__class__.current_tab = selected_tab # for persistence self.__class__.current_tab = selected_tab # for persistence
self.display() self.display()
@ -541,7 +541,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
self.ti_models, self.ti_models,
] ]
for section in ui_sections: for section in ui_sections:
if not "models_selected" in section: if "models_selected" not in section:
continue continue
selected = set([section["models"][x] for x in section["models_selected"].value]) selected = set([section["models"][x] for x in section["models_selected"].value])
models_to_install = [x for x in selected if not self.all_models[x].installed] models_to_install = [x for x in selected if not self.all_models[x].installed]
@ -637,7 +637,7 @@ def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection) -> SchedulerPre
return None return None
else: else:
return response return response
except: except Exception:
return None return None
@ -673,8 +673,7 @@ def process_and_execute(
def select_and_download_models(opt: Namespace): def select_and_download_models(opt: Namespace):
precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
config.precision = precision config.precision = precision
helper = lambda x: ask_user_for_prediction_type(x) installer = ModelInstall(config, prediction_type_helper=ask_user_for_prediction_type)
installer = ModelInstall(config, prediction_type_helper=helper)
if opt.list_models: if opt.list_models:
installer.list_models(opt.list_models) installer.list_models(opt.list_models)
elif opt.add or opt.delete: elif opt.add or opt.delete:

View File

@ -102,8 +102,8 @@ def set_min_terminal_size(min_cols: int, min_lines: int) -> bool:
class IntSlider(npyscreen.Slider): class IntSlider(npyscreen.Slider):
def translate_value(self): def translate_value(self):
stri = "%2d / %2d" % (self.value, self.out_of) stri = "%2d / %2d" % (self.value, self.out_of)
l = (len(str(self.out_of))) * 2 + 4 length = (len(str(self.out_of))) * 2 + 4
stri = stri.rjust(l) stri = stri.rjust(length)
return stri return stri
@ -167,8 +167,8 @@ class FloatSlider(npyscreen.Slider):
# this is supposed to adjust display precision, but doesn't # this is supposed to adjust display precision, but doesn't
def translate_value(self): def translate_value(self):
stri = "%3.2f / %3.2f" % (self.value, self.out_of) stri = "%3.2f / %3.2f" % (self.value, self.out_of)
l = (len(str(self.out_of))) * 2 + 4 length = (len(str(self.out_of))) * 2 + 4
stri = stri.rjust(l) stri = stri.rjust(length)
return stri return stri

View File

@ -1,4 +1,3 @@
import os
import sys import sys
import argparse import argparse

View File

@ -1,4 +1,4 @@
""" """
Initialization file for invokeai.frontend.merge Initialization file for invokeai.frontend.merge
""" """
from .merge_diffusers import main as invokeai_merge_diffusers from .merge_diffusers import main as invokeai_merge_diffusers # noqa: F401

View File

@ -9,19 +9,15 @@ import curses
import sys import sys
from argparse import Namespace from argparse import Namespace
from pathlib import Path from pathlib import Path
from typing import List, Union from typing import List, Optional
import npyscreen import npyscreen
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
from npyscreen import widget from npyscreen import widget
from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import ( from invokeai.backend.model_management import (
ModelMerger, ModelMerger,
MergeInterpolationMethod,
ModelManager, ModelManager,
ModelType, ModelType,
BaseModelType, BaseModelType,
@ -318,7 +314,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
else: else:
return True return True
def get_model_names(self, base_model: BaseModelType = None) -> List[str]: def get_model_names(self, base_model: Optional[BaseModelType] = None) -> List[str]:
model_names = [ model_names = [
info["model_name"] info["model_name"]
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model) for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)

View File

@ -1,4 +1,4 @@
""" """
Initialization file for invokeai.frontend.training Initialization file for invokeai.frontend.training
""" """
from .textual_inversion import main as invokeai_textual_inversion from .textual_inversion import main as invokeai_textual_inversion # noqa: F401

View File

@ -59,7 +59,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
try: try:
default = self.model_names.index(saved_args["model"]) default = self.model_names.index(saved_args["model"])
except: except Exception:
pass pass
self.add_widget_intelligent( self.add_widget_intelligent(
@ -377,7 +377,7 @@ def previous_args() -> dict:
try: try:
conf = OmegaConf.load(conf_file) conf = OmegaConf.load(conf_file)
conf["placeholder_token"] = conf["placeholder_token"].strip("<>") conf["placeholder_token"] = conf["placeholder_token"].strip("<>")
except: except Exception:
conf = None conf = None
return conf return conf

View File

@ -4,7 +4,7 @@ import { InvokeLogLevel } from 'app/logging/logger';
import { userInvoked } from 'app/store/actions'; import { userInvoked } from 'app/store/actions';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice'; import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { t } from 'i18next'; import { t } from 'i18next';
import { startCase } from 'lodash-es'; import { startCase, upperFirst } from 'lodash-es';
import { LogLevelName } from 'roarr'; import { LogLevelName } from 'roarr';
import { import {
isAnySessionRejected, isAnySessionRejected,
@ -26,6 +26,7 @@ import {
import { ProgressImage } from 'services/events/types'; import { ProgressImage } from 'services/events/types';
import { makeToast } from '../util/makeToast'; import { makeToast } from '../util/makeToast';
import { LANGUAGES } from './constants'; import { LANGUAGES } from './constants';
import { zPydanticValidationError } from './zodSchemas';
export type CancelStrategy = 'immediate' | 'scheduled'; export type CancelStrategy = 'immediate' | 'scheduled';
@ -361,9 +362,24 @@ export const systemSlice = createSlice({
state.progressImage = null; state.progressImage = null;
let errorDescription = undefined; let errorDescription = undefined;
const duration = 5000;
if (action.payload?.status === 422) { if (action.payload?.status === 422) {
errorDescription = 'Validation Error'; const result = zPydanticValidationError.safeParse(action.payload);
if (result.success) {
result.data.error.detail.map((e) => {
state.toastQueue.push(
makeToast({
title: upperFirst(e.msg),
status: 'error',
description: `Path:
${e.loc.slice(3).join('.')}`,
duration,
})
);
});
return;
}
} else if (action.payload?.error) { } else if (action.payload?.error) {
errorDescription = action.payload?.error as string; errorDescription = action.payload?.error as string;
} }
@ -373,6 +389,7 @@ export const systemSlice = createSlice({
title: t('toast.serverError'), title: t('toast.serverError'),
status: 'error', status: 'error',
description: errorDescription, description: errorDescription,
duration,
}) })
); );
}); });

View File

@ -0,0 +1,14 @@
import { z } from 'zod';
export const zPydanticValidationError = z.object({
status: z.literal(422),
error: z.object({
detail: z.array(
z.object({
loc: z.array(z.string()),
msg: z.string(),
type: z.string(),
})
),
}),
});

View File

@ -1,7 +1,7 @@
""" """
initialization file for invokeai initialization file for invokeai
""" """
from .invokeai_version import __version__ from .invokeai_version import __version__ # noqa: F401
__app_id__ = "invoke-ai/InvokeAI" __app_id__ = "invoke-ai/InvokeAI"
__app_name__ = "InvokeAI" __app_name__ = "InvokeAI"

View File

@ -8,9 +8,8 @@ from google.colab import files
from IPython.display import Image as ipyimg from IPython.display import Image as ipyimg
import ipywidgets as widgets import ipywidgets as widgets
from PIL import Image from PIL import Image
from numpy import asarray
from einops import rearrange, repeat from einops import rearrange, repeat
import torch, torchvision import torchvision
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import ismap from ldm.util import ismap
import time import time
@ -68,14 +67,14 @@ def get_custom_cond(mode):
elif mode == "text_conditional": elif mode == "text_conditional":
w = widgets.Text(value="A cake with cream!", disabled=True) w = widgets.Text(value="A cake with cream!", disabled=True)
display(w) display(w) # noqa: F821
with open(f"{dest}/{mode}/custom_{w.value[:20]}.txt", "w") as f: with open(f"{dest}/{mode}/custom_{w.value[:20]}.txt", "w") as f:
f.write(w.value) f.write(w.value)
elif mode == "class_conditional": elif mode == "class_conditional":
w = widgets.IntSlider(min=0, max=1000) w = widgets.IntSlider(min=0, max=1000)
display(w) display(w) # noqa: F821
with open(f"{dest}/{mode}/custom.txt", "w") as f: with open(f"{dest}/{mode}/custom.txt", "w") as f:
f.write(w.value) f.write(w.value)
@ -96,7 +95,7 @@ def select_cond_path(mode):
onlyfiles = [f for f in sorted(os.listdir(path))] onlyfiles = [f for f in sorted(os.listdir(path))]
selected = widgets.RadioButtons(options=onlyfiles, description="Select conditioning:", disabled=False) selected = widgets.RadioButtons(options=onlyfiles, description="Select conditioning:", disabled=False)
display(selected) display(selected) # noqa: F821
selected_path = os.path.join(path, selected.value) selected_path = os.path.join(path, selected.value)
return selected_path return selected_path
@ -123,7 +122,7 @@ def get_cond(mode, selected_path):
def visualize_cond_img(path): def visualize_cond_img(path):
display(ipyimg(filename=path)) display(ipyimg(filename=path)) # noqa: F821
def run(model, selected_path, task, custom_steps, resize_enabled=False, classifier_ckpt=None, global_step=None): def run(model, selected_path, task, custom_steps, resize_enabled=False, classifier_ckpt=None, global_step=None):
@ -331,7 +330,7 @@ def make_convolutional_sample(
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True) x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
log["sample_noquant"] = x_sample_noquant log["sample_noquant"] = x_sample_noquant
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample) log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
except: except Exception:
pass pass
log["sample"] = x_sample log["sample"] = x_sample

View File

@ -95,7 +95,14 @@ dependencies = [
"dev" = [ "dev" = [
"pudb", "pudb",
] ]
"test" = ["pytest>6.0.0", "pytest-cov", "pytest-datadir", "black"] "test" = [
"black",
"flake8",
"Flake8-pyproject",
"pytest>6.0.0",
"pytest-cov",
"pytest-datadir",
]
"xformers" = [ "xformers" = [
"xformers~=0.0.19; sys_platform!='darwin'", "xformers~=0.0.19; sys_platform!='darwin'",
"triton; sys_platform=='linux'", "triton; sys_platform=='linux'",
@ -185,6 +192,8 @@ output = "coverage/index.xml"
[tool.flake8] [tool.flake8]
max-line-length = 120 max-line-length = 120
ignore = ["E203", "E266", "E501", "W503"]
select = ["B", "C", "E", "F", "W", "T4"]
[tool.black] [tool.black]
line-length = 120 line-length = 120

View File

@ -4,7 +4,6 @@ Read a checkpoint/safetensors file and write out a template .json file containin
its metadata for use in fast model probing. its metadata for use in fast model probing.
""" """
import sys
import argparse import argparse
import json import json

View File

@ -3,11 +3,12 @@
import warnings import warnings
from invokeai.app.cli_app import invoke_cli
warnings.warn( warnings.warn(
"dream.py is being deprecated, please run invoke.py for the " "new UI/API or legacy_api.py for the old API", "dream.py is being deprecated, please run invoke.py for the " "new UI/API or legacy_api.py for the old API",
DeprecationWarning, DeprecationWarning,
) )
from invokeai.app.cli_app import invoke_cli
invoke_cli() invoke_cli()

View File

@ -2,7 +2,7 @@
"""This script reads the "Invoke" Stable Diffusion prompt embedded in files generated by invoke.py""" """This script reads the "Invoke" Stable Diffusion prompt embedded in files generated by invoke.py"""
import sys import sys
from PIL import Image, PngImagePlugin from PIL import Image
if len(sys.argv) < 2: if len(sys.argv) < 2:
print("Usage: file2prompt.py <file1.png> <file2.png> <file3.png>...") print("Usage: file2prompt.py <file1.png> <file2.png> <file3.png>...")

View File

@ -2,13 +2,11 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import os
import logging import logging
logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage()) logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage())
import os
import sys
def main(): def main():
# Change working directory to the repo root # Change working directory to the repo root

View File

@ -2,13 +2,11 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import os
import logging import logging
logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage()) logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage())
import os
import sys
def main(): def main():
# Change working directory to the repo root # Change working directory to the repo root

View File

@ -1,6 +1,7 @@
"""make variations of input image""" """make variations of input image"""
import argparse, os, sys, glob import argparse
import os
import PIL import PIL
import torch import torch
import numpy as np import numpy as np
@ -12,7 +13,6 @@ from einops import rearrange, repeat
from torchvision.utils import make_grid from torchvision.utils import make_grid
from torch import autocast from torch import autocast
from contextlib import nullcontext from contextlib import nullcontext
import time
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
@ -234,7 +234,6 @@ def main():
with torch.no_grad(): with torch.no_grad():
with precision_scope(device.type): with precision_scope(device.type):
with model.ema_scope(): with model.ema_scope():
tic = time.time()
all_samples = list() all_samples = list()
for n in trange(opt.n_iter, desc="Sampling"): for n in trange(opt.n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"): for prompts in tqdm(data, desc="data"):
@ -279,8 +278,6 @@ def main():
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png")) Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
grid_count += 1 grid_count += 1
toc = time.time()
print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.") print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.")

View File

@ -1,4 +1,6 @@
import argparse, os, sys, glob import argparse
import glob
import os
from omegaconf import OmegaConf from omegaconf import OmegaConf
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm

View File

@ -1,13 +1,13 @@
import argparse, os, sys, glob import argparse
import clip import glob
import os
import torch import torch
import torch.nn as nn
import numpy as np import numpy as np
from omegaconf import OmegaConf from omegaconf import OmegaConf
from PIL import Image from PIL import Image
from tqdm import tqdm, trange from tqdm import tqdm, trange
from itertools import islice from itertools import islice
from einops import rearrange, repeat from einops import rearrange
from torchvision.utils import make_grid from torchvision.utils import make_grid
import scann import scann
import time import time
@ -390,8 +390,8 @@ if __name__ == "__main__":
grid = make_grid(grid, nrow=n_rows) grid = make_grid(grid, nrow=n_rows)
# to image # to image
grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy() grid_np = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png")) Image.fromarray(grid_np.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
grid_count += 1 grid_count += 1
print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.") print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")

View File

@ -1,24 +1,24 @@
import argparse, os, sys, datetime, glob, importlib, csv import argparse
import datetime
import glob
import os
import sys
import numpy as np import numpy as np
import time import time
import torch import torch
import torchvision import torchvision
import pytorch_lightning as pl import pytorch_lightning as pl
from packaging import version from packaging import version
from omegaconf import OmegaConf from omegaconf import OmegaConf
from torch.utils.data import random_split, DataLoader, Dataset, Subset from torch.utils.data import DataLoader, Dataset
from functools import partial from functools import partial
from PIL import Image from PIL import Image
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import ( from pytorch_lightning.callbacks import Callback
ModelCheckpoint,
Callback,
LearningRateMonitor,
)
from pytorch_lightning.utilities.distributed import rank_zero_only from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities import rank_zero_info from pytorch_lightning.utilities import rank_zero_info
@ -651,7 +651,7 @@ if __name__ == "__main__":
trainer_config["accelerator"] = "auto" trainer_config["accelerator"] = "auto"
for k in nondefault_trainer_args(opt): for k in nondefault_trainer_args(opt):
trainer_config[k] = getattr(opt, k) trainer_config[k] = getattr(opt, k)
if not "gpus" in trainer_config: if "gpus" not in trainer_config:
del trainer_config["accelerator"] del trainer_config["accelerator"]
cpu = True cpu = True
else: else:
@ -803,7 +803,7 @@ if __name__ == "__main__":
trainer_opt.detect_anomaly = False trainer_opt.detect_anomaly = False
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
trainer.logdir = logdir ### trainer.logdir = logdir
# data # data
config.data.params.train.params.data_root = opt.data_root config.data.params.train.params.data_root = opt.data_root

View File

@ -2,7 +2,7 @@ from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder
from ldm.modules.embedding_manager import EmbeddingManager from ldm.modules.embedding_manager import EmbeddingManager
from ldm.invoke.globals import Globals from ldm.invoke.globals import Globals
import argparse, os import argparse
from functools import partial from functools import partial
import torch import torch
@ -108,7 +108,7 @@ if __name__ == "__main__":
manager.load(manager_ckpt) manager.load(manager_ckpt)
for placeholder_string in manager.string_to_token_dict: for placeholder_string in manager.string_to_token_dict:
if not placeholder_string in string_to_token_dict: if placeholder_string not in string_to_token_dict:
string_to_token_dict[placeholder_string] = manager.string_to_token_dict[placeholder_string] string_to_token_dict[placeholder_string] = manager.string_to_token_dict[placeholder_string]
string_to_param_dict[placeholder_string] = manager.string_to_param_dict[placeholder_string] string_to_param_dict[placeholder_string] = manager.string_to_param_dict[placeholder_string]

View File

@ -1,6 +1,12 @@
import argparse, os, sys, glob, datetime, yaml import argparse
import torch import datetime
import glob
import os
import sys
import time import time
import yaml
import torch
import numpy as np import numpy as np
from tqdm import trange from tqdm import trange
@ -10,7 +16,9 @@ from PIL import Image
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
rescale = lambda x: (x + 1.0) / 2.0
def rescale(x: float) -> float:
return (x + 1.0) / 2.0
def custom_to_pil(x): def custom_to_pil(x):
@ -45,7 +53,7 @@ def logs2pil(logs, keys=["sample"]):
else: else:
print(f"Unknown format for key {k}. ") print(f"Unknown format for key {k}. ")
img = None img = None
except: except Exception:
img = None img = None
imgs[k] = img imgs[k] = img
return imgs return imgs

View File

@ -1,4 +1,5 @@
import os, sys import os
import sys
import numpy as np import numpy as np
import scann import scann
import argparse import argparse

View File

@ -1,4 +1,5 @@
import argparse, os, sys, glob import argparse
import os
import torch import torch
import numpy as np import numpy as np
from omegaconf import OmegaConf from omegaconf import OmegaConf
@ -7,10 +8,9 @@ from tqdm import tqdm, trange
from itertools import islice from itertools import islice
from einops import rearrange from einops import rearrange
from torchvision.utils import make_grid from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
from torch import autocast from torch import autocast
from contextlib import contextmanager, nullcontext from contextlib import nullcontext
import k_diffusion as K import k_diffusion as K
import torch.nn as nn import torch.nn as nn
@ -251,7 +251,6 @@ def main():
with torch.no_grad(): with torch.no_grad():
with precision_scope(device.type): with precision_scope(device.type):
with model.ema_scope(): with model.ema_scope():
tic = time.time()
all_samples = list() all_samples = list()
for n in trange(opt.n_iter, desc="Sampling"): for n in trange(opt.n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"): for prompts in tqdm(data, desc="data"):
@ -310,8 +309,6 @@ def main():
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png")) Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
grid_count += 1 grid_count += 1
toc = time.time()
print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.") print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.")

Some files were not shown because too many files have changed in this diff Show More