mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00

Resolving merge conflicts for flake8

parent f6db9da06c
commit 537ae2f901
@@ -407,7 +407,7 @@ def get_pip_from_venv(venv_path: Path) -> str:
     :rtype: str
     """

-    pip = "Scripts\pip.exe" if OS == "Windows" else "bin/pip"
+    pip = "Scripts\\pip.exe" if OS == "Windows" else "bin/pip"
     return str(venv_path.expanduser().resolve() / pip)

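This first hunk is flake8's W605 fix (invalid escape sequence): in a regular string literal `\p` is not a recognized escape, and newer Pythons warn about it. A standalone sketch of the equivalent spellings (the path fragment is taken from the hunk above; the variable names are illustrative):

# W605: "\p" is an invalid escape sequence; Python emits a warning but keeps the backslash.
bad = "Scripts\pip.exe"          # flagged by flake8
good = "Scripts\\pip.exe"        # escaped backslash, as in the fix above
also_good = r"Scripts\pip.exe"   # a raw string is an equally valid spelling
assert good == also_good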
@@ -49,7 +49,7 @@ if __name__ == "__main__":

    try:
        inst.install(**args.__dict__)
-    except KeyboardInterrupt as exc:
+    except KeyboardInterrupt:
        print("\n")
        print("Ctrl-C pressed. Aborting.")
        print("Come back soon!")

@@ -70,7 +70,7 @@ def confirm_install(dest: Path) -> bool:
        )
    else:
        print(f"InvokeAI will be installed in {dest}")
-        dest_confirmed = not Confirm.ask(f"Would you like to pick a different location?", default=False)
+        dest_confirmed = not Confirm.ask("Would you like to pick a different location?", default=False)
    console.line()

    return dest_confirmed
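The `confirm_install` change is flake8's F541 (f-string without any placeholders): the `f` prefix does nothing when the string contains no `{...}` fields. A minimal illustration (names here are not from the diff):

name = "InvokeAI"
print(f"A plain message")              # F541: f-string with no placeholders
print("A plain message")               # fixed: drop the needless f prefix
print(f"{name} will be installed")     # the prefix is justified when interpolating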
@@ -90,7 +90,7 @@ def dest_path(dest=None) -> Path:
        dest = Path(dest).expanduser().resolve()
    else:
        dest = Path.cwd().expanduser().resolve()
-    prev_dest = dest.expanduser().resolve()
+    prev_dest = init_path = dest

    dest_confirmed = confirm_install(dest)

@@ -109,9 +109,9 @@ def dest_path(dest=None) -> Path:
            )

            console.line()
-            print(f"[orange3]Please select the destination directory for the installation:[/] \[{browse_start}]: ")
+            console.print(f"[orange3]Please select the destination directory for the installation:[/] \\[{browse_start}]: ")
            selected = prompt(
-                f">>> ",
+                ">>> ",
                complete_in_thread=True,
                completer=path_completer,
                default=str(browse_start) + os.sep,

@@ -134,14 +134,14 @@ def dest_path(dest=None) -> Path:
        try:
            dest.mkdir(exist_ok=True, parents=True)
            return dest
-        except PermissionError as exc:
-            print(
+        except PermissionError:
+            console.print(
                f"Failed to create directory {dest} due to insufficient permissions",
                style=Style(color="red"),
                highlight=True,
            )
-        except OSError as exc:
-            console.print_exception(exc)
+        except OSError:
+            console.print_exception()

        if Confirm.ask("Would you like to try again?"):
            dest_path(init_path)
@@ -1,6 +1,5 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
-
 from typing import Optional
 from logging import Logger
 from invokeai.app.services.board_image_record_storage import (
     SqliteBoardImageRecordStorage,

@@ -45,7 +44,7 @@ def check_internet() -> bool:
        try:
            urllib.request.urlopen(host, timeout=1)
            return True
-        except:
+        except Exception:
            return False

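`except:` to `except Exception:` is flake8's E722 (bare except). The practical difference: a bare except also traps `KeyboardInterrupt` and `SystemExit`, which derive from `BaseException`, not `Exception`. A runnable sketch (the helper is illustrative, not from the diff):

def handler():
    try:
        raise KeyboardInterrupt  # stand-in for the user pressing Ctrl-C
    except Exception:
        return "swallowed"       # a bare `except:` would land here instead

try:
    handler()
except KeyboardInterrupt:
    # KeyboardInterrupt derives from BaseException, so `except Exception`
    # lets it propagate; this line runs.
    print("Ctrl-C propagated past `except Exception`")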
@@ -34,7 +34,7 @@ async def add_image_to_board(
            board_id=board_id, image_name=image_name
        )
        return result
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=500, detail="Failed to add image to board")

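This and the following router hunks all fix flake8's F841 (local variable assigned but never used): the handlers bound the exception to `e` without reading it. If the original cause should stay attached to the HTTP error, `raise ... from` is the usual pattern; a hedged sketch under that assumption (the stand-in function is hypothetical, and this is not what the commit itself does):

from fastapi import HTTPException

def fetch_board():
    raise ValueError("backend failure")  # hypothetical stand-in for the service call

try:
    fetch_board()
except Exception as exc:
    # keep the underlying cause attached instead of discarding it
    raise HTTPException(status_code=500, detail="Failed to fetch board") from exc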
@@ -53,7 +53,7 @@ async def remove_image_from_board(
    try:
        result = ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
        return result
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=500, detail="Failed to remove image from board")

@@ -79,10 +79,10 @@ async def add_images_to_board(
                    board_id=board_id, image_name=image_name
                )
                added_image_names.append(image_name)
-            except:
+            except Exception:
                pass
        return AddImagesToBoardResult(board_id=board_id, added_image_names=added_image_names)
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=500, detail="Failed to add images to board")

@@ -105,8 +105,8 @@ async def remove_images_from_board(
            try:
                ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
                removed_image_names.append(image_name)
-            except:
+            except Exception:
                pass
        return RemoveImagesFromBoardResult(removed_image_names=removed_image_names)
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=500, detail="Failed to remove images from board")

@@ -37,7 +37,7 @@ async def create_board(
    try:
        result = ApiDependencies.invoker.services.boards.create(board_name=board_name)
        return result
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=500, detail="Failed to create board")

@@ -50,7 +50,7 @@ async def get_board(
    try:
        result = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
        return result
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=404, detail="Board not found")

@@ -73,7 +73,7 @@ async def update_board(
    try:
        result = ApiDependencies.invoker.services.boards.update(board_id=board_id, changes=changes)
        return result
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=500, detail="Failed to update board")

@@ -105,7 +105,7 @@ async def delete_board(
            deleted_board_images=deleted_board_images,
            deleted_images=[],
        )
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=500, detail="Failed to delete board")

@@ -55,7 +55,7 @@ async def upload_image(
        if crop_visible:
            bbox = pil_image.getbbox()
            pil_image = pil_image.crop(bbox)
-    except:
+    except Exception:
        # Error opening the image
        raise HTTPException(status_code=415, detail="Failed to read image")

@@ -73,7 +73,7 @@ async def upload_image(
        response.headers["Location"] = image_dto.image_url

        return image_dto
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=500, detail="Failed to create image")

@@ -85,7 +85,7 @@ async def delete_image(

    try:
        ApiDependencies.invoker.services.images.delete(image_name)
-    except Exception as e:
+    except Exception:
        # TODO: Does this need any exception handling at all?
        pass

@@ -97,7 +97,7 @@ async def clear_intermediates() -> int:
    try:
        count_deleted = ApiDependencies.invoker.services.images.delete_intermediates()
        return count_deleted
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=500, detail="Failed to clear intermediates")
    pass

@@ -115,7 +115,7 @@ async def update_image(

    try:
        return ApiDependencies.invoker.services.images.update(image_name, image_changes)
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=400, detail="Failed to update image")

@@ -131,7 +131,7 @@ async def get_image_dto(

    try:
        return ApiDependencies.invoker.services.images.get_dto(image_name)
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=404)

@@ -147,7 +147,7 @@ async def get_image_metadata(

    try:
        return ApiDependencies.invoker.services.images.get_metadata(image_name)
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=404)

@@ -183,7 +183,7 @@ async def get_image_full(
        )
        response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
        return response
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=404)

@@ -212,7 +212,7 @@ async def get_image_thumbnail(
        response = FileResponse(path, media_type="image/webp", content_disposition_type="inline")
        response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
        return response
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=404)

@@ -234,7 +234,7 @@ async def get_image_urls(
            image_url=image_url,
            thumbnail_url=thumbnail_url,
        )
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=404)

@@ -282,10 +282,10 @@ async def delete_images_from_list(
            try:
                ApiDependencies.invoker.services.images.delete(image_name)
                deleted_images.append(image_name)
-            except:
+            except Exception:
                pass
        return DeleteImagesFromListResult(deleted_images=deleted_images)
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=500, detail="Failed to delete images")

@@ -303,10 +303,10 @@ async def star_images_in_list(
            try:
                ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=True))
                updated_image_names.append(image_name)
-            except:
+            except Exception:
                pass
        return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=500, detail="Failed to star images")

@@ -320,8 +320,8 @@ async def unstar_images_in_list(
            try:
                ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=False))
                updated_image_names.append(image_name)
-            except:
+            except Exception:
                pass
        return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
-    except Exception as e:
+    except Exception:
        raise HTTPException(status_code=500, detail="Failed to unstar images")
@@ -1,12 +1,13 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

-from typing import Annotated, List, Optional, Union
+from typing import Annotated, Optional, Union

 from fastapi import Body, HTTPException, Path, Query, Response
 from fastapi.routing import APIRouter
 from pydantic.fields import Field

-from ...invocations import *
+# Importing * is bad karma but needed here for node detection
+from ...invocations import *  # noqa: F401 F403
 from ...invocations.baseinvocation import BaseInvocation
 from ...services.graph import (
     Edge,

@@ -1,6 +1,5 @@
 # Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
 import asyncio
-import sys
 from inspect import signature

 import logging

@@ -17,21 +16,11 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware
 from pathlib import Path
 from pydantic.schema import schema

 # This should come early so that modules can log their initialization properly
 from .services.config import InvokeAIAppConfig
 from ..backend.util.logging import InvokeAILogger

-app_config = InvokeAIAppConfig.get_config()
-app_config.parse_args()
-logger = InvokeAILogger.getLogger(config=app_config)
 from invokeai.version.invokeai_version import __version__

-# we call this early so that the message appears before
-# other invokeai initialization messages
-if app_config.version:
-    print(f"InvokeAI version {__version__}")
-    sys.exit(0)
-
 import invokeai.frontend.web as web_dir
 import mimetypes

@@ -40,12 +29,17 @@ from .api.routers import sessions, models, images, boards, board_images, app_inf
 from .api.sockets import SocketIO
 from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase

 import torch
-import invokeai.backend.util.hotfixes
+import invokeai.backend.util.hotfixes  # noqa: F401 (monkeypatching on import)

 if torch.backends.mps.is_available():
-    import invokeai.backend.util.mps_fixes
+    import invokeai.backend.util.mps_fixes  # noqa: F401 (monkeypatching on import)

+app_config = InvokeAIAppConfig.get_config()
+app_config.parse_args()
+logger = InvokeAILogger.getLogger(config=app_config)
+
+
 # fix for windows mimetypes registry entries being borked
 # see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352

@@ -230,13 +224,16 @@ def invoke_api():

    # replace uvicorn's loggers with InvokeAI's for consistent appearance
    for logname in ["uvicorn.access", "uvicorn"]:
-        l = logging.getLogger(logname)
-        l.handlers.clear()
+        log = logging.getLogger(logname)
+        log.handlers.clear()
        for ch in logger.handlers:
-            l.addHandler(ch)
+            log.addHandler(ch)

    loop.run_until_complete(server.serve())


 if __name__ == "__main__":
-    invoke_api()
+    if app_config.version:
+        print(f"InvokeAI version {__version__}")
+    else:
+        invoke_api()
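The `l` to `log` rename is flake8's E741 (ambiguous variable name): `l`, `O`, and `I` are easily misread as digits or each other in many fonts. A minimal sketch:

import logging

l = logging.getLogger("uvicorn")    # E741: "l" reads like "1" or "I"
log = logging.getLogger("uvicorn")  # unambiguous name, as in the fix above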
@@ -145,10 +145,10 @@ def set_autocompleter(services: InvocationServices) -> Completer:
    completer = Completer(services.model_manager)

    readline.set_completer(completer.complete)
-    # pyreadline3 does not have a set_auto_history() method
    try:
        readline.set_auto_history(True)
-    except:
+    except AttributeError:
+        # pyreadline3 does not have a set_auto_history() method
        pass
    readline.set_pre_input_hook(completer._pre_input_hook)
    readline.set_completer_delims(" ")
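Narrowing the bare except to `AttributeError` here doubles as documentation: the fallback exists only because pyreadline3 lacks `set_auto_history()`, and any other failure still surfaces. A self-contained sketch of that feature-detection pattern (the stub class is illustrative):

class PyReadline3Stub:
    pass  # stand-in for a readline backend without set_auto_history()

readline = PyReadline3Stub()
try:
    readline.set_auto_history(True)
except AttributeError:
    pass  # backend does not support auto-history; continue without it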
@@ -13,16 +13,8 @@ from pydantic.fields import Field
 # This should come early so that the logger can pick up its configuration options
 from .services.config import InvokeAIAppConfig
 from invokeai.backend.util.logging import InvokeAILogger

-config = InvokeAIAppConfig.get_config()
-config.parse_args()
-logger = InvokeAILogger().getLogger(config=config)
 from invokeai.version.invokeai_version import __version__

-# we call this early so that the message appears before other invokeai initialization messages
-if config.version:
-    print(f"InvokeAI version {__version__}")
-    sys.exit(0)
-
 from invokeai.app.services.board_image_record_storage import (
     SqliteBoardImageRecordStorage,

@@ -62,10 +54,15 @@ from .services.processor import DefaultInvocationProcessor
 from .services.sqlite import SqliteItemStorage

 import torch
-import invokeai.backend.util.hotfixes
+import invokeai.backend.util.hotfixes  # noqa: F401 (monkeypatching on import)

 if torch.backends.mps.is_available():
-    import invokeai.backend.util.mps_fixes
+    import invokeai.backend.util.mps_fixes  # noqa: F401 (monkeypatching on import)

+config = InvokeAIAppConfig.get_config()
+config.parse_args()
+logger = InvokeAILogger().getLogger(config=config)
+
+
 class CliCommand(BaseModel):

@@ -482,4 +479,7 @@ def invoke_cli():


 if __name__ == "__main__":
-    invoke_cli()
+    if config.version:
+        print(f"InvokeAI version {__version__}")
+    else:
+        invoke_cli()

@@ -5,10 +5,10 @@ from typing import Literal
 import numpy as np
 from pydantic import validator

-from invokeai.app.invocations.primitives import ImageCollectionOutput, ImageField, IntegerCollectionOutput
+from invokeai.app.invocations.primitives import IntegerCollectionOutput
 from invokeai.app.util.misc import SEED_MAX, get_random_seed

-from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIType, tags, title
+from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title


 @title("Integer Range")

@@ -12,7 +12,7 @@ from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion impor
     SDXLConditioningInfo,
 )

-from ...backend.model_management import ModelPatcher, ModelType
+from ...backend.model_management.models import ModelType
+from ...backend.model_management.lora import ModelPatcher
 from ...backend.model_management.models import ModelNotFoundException
 from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent

@@ -29,7 +29,7 @@ from pydantic import BaseModel, Field, validator
 from invokeai.app.invocations.primitives import ImageField, ImageOutput


-from ...backend.model_management import BaseModelType, ModelType
+from ...backend.model_management import BaseModelType
 from ..models.image import ImageCategory, ResourceOrigin
 from .baseinvocation import (
     BaseInvocation,
@@ -90,7 +90,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
        return im

    # Find all invalid tiles and replace with a random valid tile
-    replace_count = (tiles_mask == False).sum()
+    replace_count = (tiles_mask is False).sum()
    rng = np.random.default_rng(seed=seed)
    tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]
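One caveat on this hunk: for a NumPy boolean array the E712-style rewrite is not behavior-preserving. `tiles_mask == False` compares elementwise, while `tiles_mask is False` is an identity test against the `False` singleton, which yields a plain bool with no `.sum()` method. A sketch of the elementwise alternative that satisfies the linter (the sample array is illustrative):

import numpy as np

tiles_mask = np.array([True, False, False, True])

assert (tiles_mask == False).sum() == 2          # elementwise, but flagged by E712
# (tiles_mask is False).sum()                    # plain bool False; .sum() raises AttributeError
replace_count = np.logical_not(tiles_mask).sum() # lint-clean and still elementwise
assert replace_count == 2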
@@ -15,7 +15,7 @@ from diffusers.models.attention_processor import (
 )
 from diffusers.schedulers import DPMSolverSDEScheduler
 from diffusers.schedulers import SchedulerMixin as Scheduler
-from pydantic import BaseModel, Field, validator
+from pydantic import validator
 from torchvision.transforms.functional import resize as tv_resize

 from invokeai.app.invocations.metadata import CoreMetadata

@@ -30,7 +30,7 @@ from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend.model_management.models import ModelType, SilenceWarnings

-from ...backend.model_management import BaseModelType, ModelPatcher
+from ...backend.model_management.models import BaseModelType
+from ...backend.model_management.lora import ModelPatcher
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.stable_diffusion.diffusers_pipeline import (

@@ -45,12 +45,10 @@ from ...backend.util.devices import choose_precision, choose_torch_device
 from ..models.image import ImageCategory, ResourceOrigin
 from .baseinvocation import (
     BaseInvocation,
-    BaseInvocationOutput,
     FieldDescriptions,
     Input,
     InputField,
     InvocationContext,
-    OutputField,
     UIType,
     tags,
     title,

@@ -1,5 +1,5 @@
 import copy
-from typing import List, Literal, Optional, Union
+from typing import List, Literal, Optional

 from pydantic import BaseModel, Field

@@ -16,7 +16,6 @@ from .baseinvocation import (
     InputField,
     InvocationContext,
     OutputField,
-    UIType,
     tags,
     title,
 )

@@ -2,14 +2,13 @@

 import inspect
 import re
-from contextlib import ExitStack

+# from contextlib import ExitStack
 from typing import List, Literal, Optional, Union

 import numpy as np
 import torch
 from diffusers import ControlNetModel, DPMSolverMultistepScheduler
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.schedulers import SchedulerMixin as Scheduler
 from pydantic import BaseModel, Field, validator
 from tqdm import tqdm

@@ -72,7 +71,7 @@ class ONNXPromptInvocation(BaseInvocation):
        text_encoder_info = context.services.model_manager.get_model(
            **self.clip.text_encoder.dict(),
        )
-        with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
+        with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder:  # , ExitStack() as stack:
            loras = [
                (context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
                for lora in self.clip.loras

@@ -259,7 +258,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):

        unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())

-        with unet_info as unet, ExitStack() as stack:
+        with unet_info as unet:  # , ExitStack() as stack:
            # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
            loras = [
                (context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)

@@ -38,13 +38,10 @@ from easing_functions import (
     SineEaseInOut,
     SineEaseOut,
 )
-from matplotlib.figure import Figure
-from matplotlib.ticker import MaxNLocator
 from pydantic import BaseModel, Field

 from invokeai.app.invocations.primitives import FloatCollectionOutput

-from ...backend.util.logging import InvokeAILogger
 from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title


@@ -1,7 +1,6 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

-from typing import Literal, Optional, Tuple, Union
-from anyio import Condition
+from typing import Literal, Optional, Tuple

 from pydantic import BaseModel, Field
 import torch

@@ -1,6 +1,6 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
 from pathlib import Path
-from typing import Literal, Union
+from typing import Literal

 import cv2 as cv
 import numpy as np

@@ -1,18 +1,14 @@
 from abc import ABC, abstractmethod
 from logging import Logger
-from typing import List, Union, Optional
+from typing import Optional
 from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
 from invokeai.app.services.board_record_storage import (
     BoardRecord,
     BoardRecordStorageBase,
 )
-
-from invokeai.app.services.image_record_storage import (
-    ImageRecordStorageBase,
-    OffsetPaginatedResults,
-)
+from invokeai.app.services.image_record_storage import ImageRecordStorageBase
 from invokeai.app.services.models.board_record import BoardDTO
 from invokeai.app.services.models.image_record import ImageDTO, image_record_to_dto
 from invokeai.app.services.urls import UrlServiceBase


@@ -1,15 +1,14 @@
-from abc import ABC, abstractmethod
-from typing import Optional, cast
 import sqlite3
 import threading
-from typing import Optional, Union
 import uuid
+from abc import ABC, abstractmethod
+from typing import Optional, Union, cast

-import sqlite3
 from invokeai.app.services.image_record_storage import OffsetPaginatedResults
 from invokeai.app.services.models.board_record import (
     BoardRecord,
     deserialize_board_record,
 )

 from pydantic import BaseModel, Field, Extra

@@ -230,7 +229,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
            # Change the name of a board
            if changes.board_name is not None:
                self._cursor.execute(
-                    f"""--sql
+                    """--sql
                    UPDATE boards
                    SET board_name = ?
                    WHERE board_id = ?;

@@ -241,7 +240,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
            # Change the cover image of a board
            if changes.cover_image_name is not None:
                self._cursor.execute(
-                    f"""--sql
+                    """--sql
                    UPDATE boards
                    SET cover_image_name = ?
                    WHERE board_id = ?;
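The SQL hunks are F541 again: the triple-quoted statements carried an `f` prefix but contain no placeholders, because the values are bound separately as `?` parameters (which is also the injection-safe pattern). A runnable sketch against an in-memory database (table and values are illustrative, modeled on the hunk above):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE boards (board_id TEXT PRIMARY KEY, board_name TEXT)")
conn.execute("INSERT INTO boards VALUES ('b1', 'old name')")
conn.execute(
    """--sql
    UPDATE boards
    SET board_name = ?
    WHERE board_id = ?;
    """,
    ("new name", "b1"),  # values bound as parameters, never interpolated into the SQL
)
assert conn.execute("SELECT board_name FROM boards").fetchone()[0] == "new name"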
@@ -167,7 +167,7 @@ from argparse import ArgumentParser
 from omegaconf import OmegaConf, DictConfig, ListConfig
 from pathlib import Path
 from pydantic import BaseSettings, Field, parse_obj_as
-from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args
+from typing import ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args

 INIT_FILE = Path("invokeai.yaml")
 DB_FILE = Path("invokeai.db")

@@ -394,7 +394,7 @@ class InvokeAIAppConfig(InvokeAISettings):
    free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
    max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
    max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
-    precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='auto',description='Floating point precision', category='Memory/Performance')
+    precision : Literal['auto', 'float16', 'float32', 'autocast'] = Field(default='auto', description='Floating point precision', category='Memory/Performance')
    sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
    xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
    tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
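The `precision` change replaces `Literal[tuple([...])]`, which builds the type from a runtime tuple, with explicitly spelled members. Both evaluate at runtime, but static type checkers generally can only see the literal form. A small sketch of the resulting alias:

from typing import Literal, get_args

Precision = Literal["auto", "float16", "float32", "autocast"]
assert get_args(Precision) == ("auto", "float16", "float32", "autocast")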
@@ -415,8 +415,8 @@ class InvokeAIAppConfig(InvokeAISettings):

    log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
    # note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
-    log_format : Literal[tuple(['plain','color','syslog','legacy'])] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
-    log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
+    log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
+    log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")

    version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
    # fmt: on

@@ -438,7 +438,7 @@ class InvokeAIAppConfig(InvokeAISettings):
        if conf is None:
            try:
                conf = OmegaConf.load(self.root_dir / INIT_FILE)
-            except:
+            except Exception:
                pass
        InvokeAISettings.initconf = conf

@@ -457,7 +457,7 @@ class InvokeAIAppConfig(InvokeAISettings):
        """
        if (
            cls.singleton_config is None
-            or type(cls.singleton_config) != cls
+            or type(cls.singleton_config) is not cls
            or (kwargs and cls.singleton_init != kwargs)
        ):
            cls.singleton_config = cls(**kwargs)

@@ -9,7 +9,8 @@ import networkx as nx
 from pydantic import BaseModel, root_validator, validator
 from pydantic.fields import Field

-from ..invocations import *
+# Importing * is bad karma but needed here for node detection
+from ..invocations import *  # noqa: F401 F403
 from ..invocations.baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
@@ -445,7 +446,7 @@ class Graph(BaseModel):
        node = graph.nodes[node_id]

        # Ensure the node type matches the new node
-        if type(node) != type(new_node):
+        if type(node) is not type(new_node):
            raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}")

        # Ensure the new id is either the same or is not in the graph
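This is flake8's E721: comparing types with `!=`/`==` routes through equality protocols, while `is`/`is not` between `type()` results is a direct identity test, which is what an exact-type check means. A sketch (classes are illustrative):

class Base: ...
class Child(Base): ...

a, b = Base(), Child()
assert type(a) is not type(b)   # exact-type check, as in the fix above
assert isinstance(b, Base)      # prefer isinstance() when subclasses should also match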
@@ -632,7 +633,7 @@ class Graph(BaseModel):
            [
                t
                for input_field in input_fields
-                for t in ([input_field] if get_origin(input_field) == None else get_args(input_field))
+                for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
                if t != NoneType
            ]
        )  # Get unique types

@@ -923,7 +924,7 @@ class GraphExecutionState(BaseModel):
            None,
        )

-        if next_node_id == None:
+        if next_node_id is None:
            return None

        # Get all parents of the next node
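These two hunks are flake8's E711: comparisons to `None` should use identity. `== None` invokes `__eq__`, which arbitrary objects can override (NumPy arrays, for instance, broadcast it elementwise), whereas `is None` tests against the single `None` object. A minimal sketch:

next_node_id = None
assert not (next_node_id == None is False)  # noqa: E711 - works here, but relies on __eq__
assert next_node_id is None                 # identity test against the None singleton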
@@ -179,7 +179,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
        return None if image_name not in self.__cache else self.__cache[image_name]

    def __set_cache(self, image_name: Path, image: PILImageType):
-        if not image_name in self.__cache:
+        if image_name not in self.__cache:
            self.__cache[image_name] = image
            self.__cache_ids.put(image_name)  # TODO: this should refresh position for LRU cache
            if len(self.__cache) > self.__max_cache_size:
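And this one is E713: `not x in y` parses as `not (x in y)`, so the behavior is identical, but `x not in y` states the membership test directly. A sketch:

cache = {"a.png": object()}
if not "b.png" in cache:   # noqa: E713 - same result, but reads as negating the whole test
    pass
if "b.png" not in cache:   # idiomatic spelling, as in the fix above
    pass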
@@ -282,7 +282,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
        self._lock.acquire()

        self._cursor.execute(
-            f"""--sql
+            """--sql
            SELECT images.metadata FROM images
            WHERE image_name = ?;
            """,

@@ -309,7 +309,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
            # Change the category of the image
            if changes.image_category is not None:
                self._cursor.execute(
-                    f"""--sql
+                    """--sql
                    UPDATE images
                    SET image_category = ?
                    WHERE image_name = ?;

@@ -320,7 +320,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
            # Change the session associated with the image
            if changes.session_id is not None:
                self._cursor.execute(
-                    f"""--sql
+                    """--sql
                    UPDATE images
                    SET session_id = ?
                    WHERE image_name = ?;

@@ -331,7 +331,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
            # Change the image's `is_intermediate`` flag
            if changes.is_intermediate is not None:
                self._cursor.execute(
-                    f"""--sql
+                    """--sql
                    UPDATE images
                    SET is_intermediate = ?
                    WHERE image_name = ?;

@@ -342,7 +342,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
            # Change the image's `starred`` state
            if changes.starred is not None:
                self._cursor.execute(
-                    f"""--sql
+                    """--sql
                    UPDATE images
                    SET starred = ?
                    WHERE image_name = ?;

@@ -1,4 +1,3 @@
-import json
 from abc import ABC, abstractmethod
 from logging import Logger
 from typing import TYPE_CHECKING, Optional

@@ -379,10 +378,10 @@ class ImageService(ImageServiceABC):
            self._services.image_files.delete(image_name)
            self._services.image_records.delete(image_name)
        except ImageRecordDeleteException:
-            self._services.logger.error(f"Failed to delete image record")
+            self._services.logger.error("Failed to delete image record")
            raise
        except ImageFileDeleteException:
-            self._services.logger.error(f"Failed to delete image file")
+            self._services.logger.error("Failed to delete image file")
            raise
        except Exception as e:
            self._services.logger.error("Problem deleting image record and file")

@@ -395,10 +394,10 @@ class ImageService(ImageServiceABC):
                self._services.image_files.delete(image_name)
            self._services.image_records.delete_many(image_names)
        except ImageRecordDeleteException:
-            self._services.logger.error(f"Failed to delete image records")
+            self._services.logger.error("Failed to delete image records")
            raise
        except ImageFileDeleteException:
-            self._services.logger.error(f"Failed to delete image files")
+            self._services.logger.error("Failed to delete image files")
            raise
        except Exception as e:
            self._services.logger.error("Problem deleting image records and files")

@@ -412,10 +411,10 @@ class ImageService(ImageServiceABC):
                self._services.image_files.delete(image_name)
            return count
        except ImageRecordDeleteException:
-            self._services.logger.error(f"Failed to delete image records")
+            self._services.logger.error("Failed to delete image records")
            raise
        except ImageFileDeleteException:
-            self._services.logger.error(f"Failed to delete image files")
+            self._services.logger.error("Failed to delete image files")
            raise
        except Exception as e:
            self._services.logger.error("Problem deleting image records and files")

@@ -7,6 +7,7 @@ if TYPE_CHECKING:
     from invokeai.app.services.board_images import BoardImagesServiceABC
     from invokeai.app.services.boards import BoardServiceABC
     from invokeai.app.services.images import ImageServiceABC
+    from invokeai.app.services.invocation_stats import InvocationStatsServiceBase
     from invokeai.app.services.model_manager_service import ModelManagerServiceBase
     from invokeai.app.services.events import EventServiceBase
     from invokeai.app.services.latent_storage import LatentsStorageBase

@@ -1,7 +1,6 @@
 # Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
-"""Utility to collect execution time and GPU usage stats on invocations in flight"""
+"""Utility to collect execution time and GPU usage stats on invocations in flight

-"""
 Usage:

 statistics = InvocationStatsService(graph_execution_manager)

@@ -60,7 +60,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
        return None if name not in self.__cache else self.__cache[name]

    def __set_cache(self, name: str, data: torch.Tensor):
-        if not name in self.__cache:
+        if name not in self.__cache:
            self.__cache[name] = data
            self.__cache_ids.put(name)
            if self.__cache_ids.qsize() > self.__max_cache_size:

@@ -1,3 +1,4 @@
+from typing import Union
 import torch
 import numpy as np
 import cv2

@@ -5,7 +6,7 @@ from PIL import Image
 from diffusers.utils import PIL_INTERPOLATION

 from einops import rearrange
-from controlnet_aux.util import HWC3, resize_image
+from controlnet_aux.util import HWC3

 ###################################################################
 # Copy of scripts/lvminthin.py from Mikubill/sd-webui-controlnet
|
||||
k0 = float(h) / old_h
|
||||
k1 = float(w) / old_w
|
||||
|
||||
safeint = lambda x: int(np.round(x))
|
||||
def safeint(x: Union[int, float]) -> int:
|
||||
return int(np.round(x))
|
||||
|
||||
# if resize_mode == external_code.ResizeMode.OUTER_FIT:
|
||||
if resize_mode == "fill_resize": # OUTER_FIT
|
||||
|
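The lambda-to-def rewrite is flake8's E731: binding a lambda to a name gives you a function whose `__name__` is just `<lambda>`, while `def` produces a properly named function for tracebacks and profilers, and allows annotations. A runnable sketch:

import numpy as np

# E731 flags this form:
#   safeint = lambda x: int(np.round(x))
def safeint(x) -> int:  # a def gives the function a real name in tracebacks
    return int(np.round(x))

assert safeint(2.6) == 3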
@@ -5,7 +5,6 @@ from invokeai.app.models.image import ProgressImage
 from ..invocations.baseinvocation import InvocationContext
 from ...backend.util.util import image_to_dataURL
 from ...backend.stable_diffusion import PipelineIntermediateState
-from invokeai.app.services.config import InvokeAIAppConfig
 from ...backend.model_management.models import BaseModelType


@@ -1,5 +1,5 @@
 """
 Initialization file for invokeai.backend
 """
-from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubModelType, ModelInfo
-from .model_management.models import SilenceWarnings
+from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubModelType, ModelInfo  # noqa: F401
+from .model_management.models import SilenceWarnings  # noqa: F401

@@ -1,14 +1,16 @@
 """
 Initialization file for invokeai.backend.image_util methods.
 """
-from .patchmatch import PatchMatch
-from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata
-from .seamless import configure_model_padding
-from .txt2mask import Txt2Mask
-from .util import InitImageResizer, make_grid
+from .patchmatch import PatchMatch  # noqa: F401
+from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata  # noqa: F401
+from .seamless import configure_model_padding  # noqa: F401
+from .txt2mask import Txt2Mask  # noqa: F401
+from .util import InitImageResizer, make_grid  # noqa: F401


 def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False):
     from PIL import ImageDraw

     if not debug_status:
         return

@@ -26,7 +26,7 @@ class PngWriter:
        dirlist = sorted(os.listdir(self.outdir), reverse=True)
        # find the first filename that matches our pattern or return 000000.0.png
        existing_name = next(
-            (f for f in dirlist if re.match("^(\d+)\..*\.png", f)),
+            (f for f in dirlist if re.match(r"^(\d+)\..*\.png", f)),
            "0000000.0.png",
        )
        basecount = int(existing_name.split(".", 1)[0]) + 1

@@ -98,11 +98,11 @@ class PromptFormatter:
        # to do: put model name into the t2i object
        # switches.append(f'--model{t2i.model_name}')
        if opt.seamless or t2i.seamless:
-            switches.append(f"--seamless")
+            switches.append("--seamless")
        if opt.init_img:
            switches.append(f"-I{opt.init_img}")
        if opt.fit:
-            switches.append(f"--fit")
+            switches.append("--fit")
        if opt.strength and opt.init_img is not None:
            switches.append(f"-f{opt.strength or t2i.strength}")
        if opt.gfpgan_strength:

@@ -52,7 +52,6 @@ from invokeai.frontend.install.widgets import (
     SingleSelectColumns,
     CenteredButtonPress,
     FileBox,
-    IntTitleSlider,
     set_min_terminal_size,
     CyclingForm,
     MIN_COLS,

@@ -525,7 +525,7 @@ def do_migrate(src_directory: Path, dest_directory: Path):
    if version_3:  # write into the dest directory
        try:
            shutil.copy(dest_directory / "configs" / "models.yaml", config_file)
-        except:
+        except Exception:
            MigrateTo3.initialize_yaml(config_file)
        mgr = ModelManager(config_file)  # important to initialize BEFORE moving the models directory
        (dest_directory / "models").replace(dest_models)

@@ -12,7 +12,6 @@ from typing import Optional, List, Dict, Callable, Union, Set
 import requests
 from diffusers import DiffusionPipeline
 from diffusers import logging as dlogging
-import onnx
 import torch
 from huggingface_hub import hf_hub_url, HfFolder, HfApi
 from omegaconf import OmegaConf

@@ -1,10 +1,10 @@
 """
 Initialization file for invokeai.backend.model_management
 """
-from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
-from .model_cache import ModelCache
-from .lora import ModelPatcher, ONNXModelPatcher
-from .models import (
+from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType  # noqa: F401
+from .model_cache import ModelCache  # noqa: F401
+from .lora import ModelPatcher, ONNXModelPatcher  # noqa: F401
+from .models import (  # noqa: F401
     BaseModelType,
     ModelType,
     SubModelType,
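The `# noqa: F401` annotations mark re-exports in package `__init__.py` files: flake8 cannot tell an unused import from an intentional public alias, so the warning is silenced per line. Declaring `__all__` is the other common convention for the same intent (a hedged alternative, not what this commit does); a sketch of a hypothetical package `__init__.py`:

# __init__.py of a hypothetical package re-exporting a submodule name
from .model_manager import ModelManager  # noqa: F401

# alternative: state the public surface explicitly
__all__ = ["ModelManager"]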
@@ -12,5 +12,4 @@ from .models import (
     ModelNotFoundException,
     DuplicateModelException,
 )
-from .model_merge import ModelMerger, MergeInterpolationMethod
-from .lora import ModelPatcher
+from .model_merge import ModelMerger, MergeInterpolationMethod  # noqa: F401

@@ -5,21 +5,16 @@ from contextlib import contextmanager
 from typing import Optional, Dict, Tuple, Any, Union, List
 from pathlib import Path

-import torch
-from safetensors.torch import load_file
-from torch.utils.hooks import RemovableHandle
-
-from diffusers.models import UNet2DConditionModel
-from transformers import CLIPTextModel
-from onnx import numpy_helper
-from onnxruntime import OrtValue
-import numpy as np
-
+import torch
+from compel.embeddings_provider import BaseTextualInversionManager
+from diffusers.models import UNet2DConditionModel
+from safetensors.torch import load_file
+from transformers import CLIPTextModel, CLIPTokenizer

 from .models.lora import LoRAModel

 """
 loras = [
     (lora_model1, 0.7),

@@ -52,7 +47,7 @@ class ModelPatcher:
                    module = module.get_submodule(submodule_name)
                    module_key += "." + submodule_name
                    submodule_name = key_parts.pop(0)
-            except:
+            except Exception:
                submodule_name += "_" + key_parts.pop(0)

        module = module.get_submodule(submodule_name)

@@ -341,7 +336,7 @@ class ONNXModelPatcher:
    def apply_lora(
        cls,
        model: IAIOnnxRuntimeModel,
-        loras: List[Tuple[LoraModel, float]],
+        loras: List[Tuple[LoRAModel, float]],
        prefix: str,
    ):
        from .models.base import IAIOnnxRuntimeModel

@@ -273,7 +273,7 @@ class ModelCache(object):
            self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
            self.cache._print_cuda_stats()

-        except:
+        except Exception:
            self.cache_entry.unlock()
            raise

@@ -419,12 +419,12 @@ class ModelManager(object):
        base_model_str, model_type_str, model_name = model_key.split("/", 2)
        try:
            model_type = ModelType(model_type_str)
-        except:
+        except Exception:
            raise Exception(f"Unknown model type: {model_type_str}")

        try:
            base_model = BaseModelType(base_model_str)
-        except:
+        except Exception:
            raise Exception(f"Unknown base model: {base_model_str}")

        return (model_name, base_model, model_type)

@@ -855,7 +855,7 @@ class ModelManager(object):
                info.pop("config")

                result = self.add_model(model_name, base_model, model_type, model_attributes=info, clobber=True)
-            except:
+            except Exception:
                # something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
                rmtree(new_diffusers_path)
                raise

@@ -1042,7 +1042,7 @@ class ModelManager(object):
        # Patch in the SD VAE from core so that it is available for use by the UI
        try:
            self.heuristic_import({str(self.resolve_model_path("core/convert/sd-vae-ft-mse"))})
-        except:
+        except Exception:
            pass

        installer = ModelInstall(

@@ -217,9 +217,9 @@ class ModelProbe(object):
            raise "The model {model_name} is potentially infected by malware. Aborting import."


-###################################################3
+# ##################################################3
 # Checkpoint probing
-###################################################3
+# ##################################################3
 class ProbeBase(object):
     def get_base_type(self) -> BaseModelType:
         pass

@@ -431,7 +431,7 @@ class PipelineFolderProbe(FolderProbeBase):
                    return ModelVariantType.Depth
                elif in_channels == 4:
                    return ModelVariantType.Normal
-        except:
+        except Exception:
            pass
        return ModelVariantType.Normal

@@ -2,7 +2,7 @@ import inspect
 from enum import Enum
 from pydantic import BaseModel
 from typing import Literal, get_origin
-from .base import (
+from .base import (  # noqa: F401
     BaseModelType,
     ModelType,
     SubModelType,

@@ -118,7 +118,7 @@ def get_model_config_enums():
    fields = model_config.__annotations__
    try:
        field = fields["model_format"]
-    except:
+    except Exception:
        raise Exception("format field not found")

    # model_format: None
@@ -3,27 +3,28 @@ import os
 import sys
 import typing
 import inspect
-from enum import Enum
+import warnings
 from abc import ABCMeta, abstractmethod
+from contextlib import suppress
+from enum import Enum
+from pathlib import Path
 from picklescan.scanner import scan_file_path

 import torch
 import numpy as np
-import safetensors.torch
-from pathlib import Path
-from diffusers import DiffusionPipeline, ConfigMixin, OnnxRuntimeModel
-
-from contextlib import suppress
-from pydantic import BaseModel, Field
-from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
-
 import onnx
+import safetensors.torch
+from diffusers import DiffusionPipeline, ConfigMixin
 from onnx import numpy_helper
 from onnxruntime import (
     InferenceSession,
     SessionOptions,
     get_available_providers,
 )
+from pydantic import BaseModel, Field
+from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
+from diffusers import logging as diffusers_logging
+from transformers import logging as transformers_logging


 class DuplicateModelException(Exception):

@@ -171,7 +172,7 @@ class ModelBase(metaclass=ABCMeta):
        fields = value.__annotations__
        try:
            field = fields["model_format"]
-        except:
+        except Exception:
            raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})")

        if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):

@@ -244,7 +245,7 @@ class DiffusersModel(ModelBase):
        try:
            config_data = DiffusionPipeline.load_config(self.model_path)
            # config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
-        except:
+        except Exception:
            raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")

        config_data.pop("_ignore_files", None)

@@ -343,7 +344,7 @@ def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, vari
            with open(os.path.join(model_path, file), "r") as f:
                index_data = json.loads(f.read())
            return int(index_data["metadata"]["total_size"])
-        except:
+        except Exception:
            pass

    # calculate files size if there is no index file

@@ -440,7 +441,7 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
    if str(path).endswith(".safetensors"):
        try:
            checkpoint = _fast_safetensors_reader(path)
-        except:
+        except Exception:
            # TODO: create issue for support "meta"?
            checkpoint = safetensors.torch.load_file(path, device="cpu")
    else:

@@ -452,11 +453,6 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
    return checkpoint


-import warnings
-from diffusers import logging as diffusers_logging
-from transformers import logging as transformers_logging
-
-
 class SilenceWarnings(object):
     def __init__(self):
         self.transformers_verbosity = transformers_logging.get_verbosity()

@@ -639,7 +635,7 @@ class IAIOnnxRuntimeModel:
            raise Exception("You should call create_session before running model")

        inputs = {k: np.array(v) for k, v in kwargs.items()}
-        output_names = self.session.get_outputs()
+        # output_names = self.session.get_outputs()
        # for k in inputs:
        #     self.io_binding.bind_cpu_input(k, inputs[k])
        # for name in output_names:

@@ -43,7 +43,7 @@ class ControlNetModel(ModelBase):
        try:
            config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
            # config = json.loads(os.path.join(self.model_path, "config.json"))
-        except:
+        except Exception:
            raise Exception("Invalid controlnet model! (config.json not found or invalid)")

        model_class_name = config.get("_class_name", None)

@@ -53,7 +53,7 @@ class ControlNetModel(ModelBase):
        try:
            self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
            self.model_size = calc_model_size_by_fs(self.model_path)
-        except:
+        except Exception:
            raise Exception("Invalid ControlNet model!")

    def get_size(self, child_type: Optional[SubModelType] = None):

@@ -78,7 +78,7 @@ class ControlNetModel(ModelBase):
                    variant=variant,
                )
                break
-            except:
+            except Exception:
                pass
        if not model:
            raise ModelNotFoundException()

@@ -330,5 +330,5 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
            config_path = config_path.relative_to(app_config.root_path)
        return str(config_path)

-    except:
+    except Exception:
        return None

@@ -1,25 +1,16 @@
-import os
-import json
 from enum import Enum
 from pydantic import Field
-from pathlib import Path
-from typing import Literal, Optional, Union
+from typing import Literal
 from .base import (
-    ModelBase,
     ModelConfigBase,
     BaseModelType,
     ModelType,
-    SubModelType,
     ModelVariantType,
     DiffusersModel,
     SchedulerPredictionType,
-    SilenceWarnings,
-    read_checkpoint_meta,
     classproperty,
     OnnxRuntimeModel,
     IAIOnnxRuntimeModel,
 )
 from invokeai.app.services.config import InvokeAIAppConfig


 class StableDiffusionOnnxModelFormat(str, Enum):
@@ -44,14 +44,14 @@ class VaeModel(ModelBase):
        try:
            config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
            # config = json.loads(os.path.join(self.model_path, "config.json"))
-        except:
+        except Exception:
            raise Exception("Invalid vae model! (config.json not found or invalid)")

        try:
            vae_class_name = config.get("_class_name", "AutoencoderKL")
            self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name])
            self.model_size = calc_model_size_by_fs(self.model_path)
-        except:
+        except Exception:
            raise Exception("Invalid vae model! (Unkown vae type)")

    def get_size(self, child_type: Optional[SubModelType] = None):

@@ -1,11 +1,15 @@
 """
 Initialization file for the invokeai.backend.stable_diffusion package
 """
-from .diffusers_pipeline import (
+from .diffusers_pipeline import (  # noqa: F401
     ConditioningData,
     PipelineIntermediateState,
     StableDiffusionGeneratorPipeline,
 )
-from .diffusion import InvokeAIDiffuserComponent
-from .diffusion.cross_attention_map_saving import AttentionMapSaver
-from .diffusion.shared_invokeai_diffusion import PostprocessingSettings, BasicConditioningInfo, SDXLConditioningInfo
+from .diffusion import InvokeAIDiffuserComponent  # noqa: F401
+from .diffusion.cross_attention_map_saving import AttentionMapSaver  # noqa: F401
+from .diffusion.shared_invokeai_diffusion import (  # noqa: F401
+    PostprocessingSettings,
+    BasicConditioningInfo,
+    SDXLConditioningInfo,
+)

@@ -2,10 +2,8 @@ from __future__ import annotations

 import dataclasses
 import inspect
-import math
-import secrets
 from dataclasses import dataclass, field
-from typing import Any, Callable, Generic, List, Optional, Type, Union
+from typing import Any, Callable, List, Optional, Union

 import PIL.Image
 import einops

@@ -1,9 +1,9 @@
 """
 Initialization file for invokeai.models.diffusion
 """
-from .cross_attention_control import InvokeAICrossAttentionMixin
-from .cross_attention_map_saving import AttentionMapSaver
-from .shared_invokeai_diffusion import (
+from .cross_attention_control import InvokeAICrossAttentionMixin  # noqa: F401
+from .cross_attention_map_saving import AttentionMapSaver  # noqa: F401
+from .shared_invokeai_diffusion import (  # noqa: F401
     InvokeAIDiffuserComponent,
     PostprocessingSettings,
     BasicConditioningInfo,

@@ -4,6 +4,7 @@


 import enum
+import math
 from dataclasses import dataclass, field
 from typing import Callable, Optional

 import diffusers

@@ -12,6 +13,11 @@ import torch
 from compel.cross_attention_control import Arguments
 from diffusers.models.unet_2d_condition import UNet2DConditionModel
 from diffusers.models.attention_processor import AttentionProcessor
+from diffusers.models.attention_processor import (
+    Attention,
+    AttnProcessor,
+    SlicedAttnProcessor,
+)
 from torch import nn

 import invokeai.backend.util.logging as logger

@@ -522,14 +528,6 @@ class AttnProcessor:
        return hidden_states

 """
-from dataclasses import dataclass, field
-
-import torch
-from diffusers.models.attention_processor import (
-    Attention,
-    AttnProcessor,
-    SlicedAttnProcessor,
-)


 @dataclass

@@ -5,8 +5,6 @@ import torch
 from torchvision.transforms.functional import InterpolationMode
 from torchvision.transforms.functional import resize as tv_resize

-from .cross_attention_control import CrossAttentionType, get_cross_attention_modules
-

 class AttentionMapSaver:
     def __init__(self, token_ids: range, latents_shape: torch.Size):

@@ -3,15 +3,12 @@ from __future__ import annotations
 from contextlib import contextmanager
 from dataclasses import dataclass
 import math
-from typing import Any, Callable, Dict, Optional, Union, List
+from typing import Any, Callable, Optional, Union

 import numpy as np
 import torch
-from diffusers import UNet2DConditionModel
-from diffusers.models.attention_processor import AttentionProcessor
 from typing_extensions import TypeAlias

 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config import InvokeAIAppConfig

 from .cross_attention_control import (

@@ -579,7 +576,7 @@ class InvokeAIDiffuserComponent:
            latents.to(device="cpu")

            if (
-                h_symmetry_time_pct != None
+                h_symmetry_time_pct is not None
                and self.last_percent_through < h_symmetry_time_pct
                and percent_through >= h_symmetry_time_pct
            ):

@@ -595,7 +592,7 @@ class InvokeAIDiffuserComponent:
            )

            if (
-                v_symmetry_time_pct != None
+                v_symmetry_time_pct is not None
                and self.last_percent_through < v_symmetry_time_pct
                and percent_through >= v_symmetry_time_pct
            ):
@ -1,6 +1,6 @@
from ldm.modules.image_degradation.bsrgan import (
from ldm.modules.image_degradation.bsrgan import (  # noqa: F401
    degradation_bsrgan_variant as degradation_fn_bsr,
)
from ldm.modules.image_degradation.bsrgan_light import (
from ldm.modules.image_degradation.bsrgan_light import (  # noqa: F401
    degradation_bsrgan_variant as degradation_fn_bsr_light,
)

@ -573,14 +573,15 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
        hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    """
    image = util.uint2single(image)
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf
    jpeg_prob, scale2_prob = 0.9, 0.25
    # isp_prob = 0.25  # uncomment with `if i== 6` block below
    # sf_ori = sf  # uncomment with `if i== 6` block below

    h1, w1 = image.shape[:2]
    image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...]  # mod crop
    h, w = image.shape[:2]

    hq = image.copy()
    # hq = image.copy()  # uncomment with `if i== 6` block below

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:

@ -777,7 +778,7 @@ if __name__ == "__main__":
        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
        print(img_lq.shape)
        print("bicubic", img_lq_bicubic.shape)
        print(img_hq.shape)
        # print(img_hq.shape)
        lq_nearest = cv2.resize(
            util.single2uint(img_lq),
            (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),

@ -788,5 +789,6 @@ if __name__ == "__main__":
            (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
            interpolation=0,
        )
        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
        # img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest], axis=1)
        util.imsave(img_concat, str(i) + ".png")

@ -577,14 +577,15 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
        hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    """
    image = util.uint2single(image)
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf
    jpeg_prob, scale2_prob = 0.9, 0.25
    # isp_prob = 0.25  # uncomment with `if i== 6` block below
    # sf_ori = sf  # uncomment with `if i== 6` block below

    h1, w1 = image.shape[:2]
    image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...]  # mod crop
    h, w = image.shape[:2]

    hq = image.copy()
    # hq = image.copy()  # uncomment with `if i== 6` block below

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:

@ -8,8 +8,6 @@ import numpy as np
import torch
from torchvision.utils import make_grid

# import matplotlib.pyplot as plt  # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py

import invokeai.backend.util.logging as logger

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

@ -50,6 +48,8 @@ def get_timestamp():


def imshow(x, title=None, cbar=False, figsize=None):
    import matplotlib.pyplot as plt

    plt.figure(figsize=figsize)
    plt.imshow(np.squeeze(x), interpolation="nearest", cmap="gray")
    if title:

@ -60,6 +60,8 @@ def imshow(x, title=None, cbar=False, figsize=None):


def surf(Z, cmap="rainbow", figsize=None):
    import matplotlib.pyplot as plt

    plt.figure(figsize=figsize)
    ax3 = plt.axes(projection="3d")
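The two hunks above move `matplotlib.pyplot` from module scope into the only functions that use it, so importing the utility module no longer drags in matplotlib. A standalone sketch of the deferred-import pattern:

def plot_histogram(values):
    # Imported on first call rather than at module import time; a missing
    # matplotlib install only fails for callers who actually plot.
    import matplotlib.pyplot as plt

    plt.hist(values)
    plt.show()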
@ -1 +1 @@
from .schedulers import SCHEDULER_MAP
from .schedulers import SCHEDULER_MAP  # noqa: F401

@ -1,4 +1,4 @@
"""
Initialization file for invokeai.backend.training
"""
from .textual_inversion_training import do_textual_inversion_training, parse_args
from .textual_inversion_training import do_textual_inversion_training, parse_args  # noqa: F401

@ -1,7 +1,7 @@
"""
Initialization file for invokeai.backend.util
"""
from .devices import (
from .devices import (  # noqa: F401
    CPU_DEVICE,
    CUDA_DEVICE,
    MPS_DEVICE,

@ -10,5 +10,5 @@ from .devices import (
    normalize_device,
    torch_dtype,
)
from .log import write_log
from .util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name, Chdir
from .log import write_log  # noqa: F401
from .util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name, Chdir  # noqa: F401

@ -25,10 +25,15 @@ from diffusers.models.unet_2d_condition import UNet2DConditionModel
import diffusers
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module

from invokeai.backend.util.logging import InvokeAILogger

# TODO: create PR to diffusers
# Modified ControlNetModel with encoder_attention_mask argument added


logger = InvokeAILogger.getLogger(__name__)


class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
    """
    A ControlNet model.

@ -111,7 +116,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
            "DownBlock2D",
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
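`Tuple[int]` means a tuple of exactly one int, so the old annotation under-described the four-element default; `Tuple[int, ...]` is the variadic spelling. Standalone illustration:

from typing import Tuple

one: Tuple[int] = (320,)                        # exactly one element
many: Tuple[int, ...] = (320, 640, 1280, 1280)  # any length, all ints

# A checker such as mypy rejects a multi-element value against Tuple[int]:
# bad: Tuple[int] = (320, 640)   # error: expected a 1-tuple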
@ -27,8 +27,8 @@ def write_log_message(results, output_cntr):
    log_lines = [f"{path}: {prompt}\n" for path, prompt in results]
    if len(log_lines) > 1:
        subcntr = 1
        for l in log_lines:
            print(f"[{output_cntr}.{subcntr}] {l}", end="")
        for ll in log_lines:
            print(f"[{output_cntr}.{subcntr}] {ll}", end="")
            subcntr += 1
    else:
        print(f"[{output_cntr}] {log_lines[0]}", end="")
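Renaming the loop variable `l` to `ll` clears flake8's E741, which rejects the identifiers `l`, `O`, and `I` because they are easily mistaken for `1` and `0`. Standalone illustration:

lines = ["alpha", "beta"]

for l in lines:  # E741: ambiguous variable name 'l'
    print(l)

for line in lines:  # a descriptive name avoids the ambiguity entirely
    print(line)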
@ -182,13 +182,13 @@ import urllib.parse
from abc import abstractmethod
from pathlib import Path

from invokeai.app.services.config import InvokeAIAppConfig, get_invokeai_config
from invokeai.app.services.config import InvokeAIAppConfig

try:
    import syslog

    SYSLOG_AVAILABLE = True
except:
except ImportError:
    SYSLOG_AVAILABLE = False


@ -417,7 +417,7 @@ class InvokeAILogger(object):
                    syslog_args["socktype"] = _SOCK_MAP[arg_value[0]]
                else:
                    syslog_args["address"] = arg_name
        except:
        except Exception:
            raise ValueError(f"{args} is not a value argument list for syslog logging")
        return logging.handlers.SysLogHandler(**syslog_args)
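Replacing bare `except:` with `except ImportError:` (or `except Exception:`) addresses flake8's E722. A bare except also traps `KeyboardInterrupt` and `SystemExit`, which is almost never wanted. Standalone sketch:

import time

def retry_forever():
    while True:
        try:
            time.sleep(1)
        except:  # E722: Ctrl-C raises KeyboardInterrupt, which this swallows
            continue

def retry_politely():
    while True:
        try:
            time.sleep(1)
        except Exception:  # ordinary errors retry; Ctrl-C still exits
            continue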
@ -191,7 +191,7 @@ class ChunkedSlicedAttnProcessor:
        assert value.shape[0] == 1
        assert hidden_states.shape[0] == 1

        dtype = query.dtype
        # dtype = query.dtype
        if attn.upcast_attention:
            query = query.float()
            key = key.float()

@ -84,7 +84,7 @@ def count_params(model, verbose=False):


def instantiate_from_config(config, **kwargs):
    if not "target" in config:
    if "target" not in config:
        if config == "__is_first_stage__":
            return None
        elif config == "__is_unconditional__":
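`not "target" in config` parses as `not ("target" in config)`, so it behaves correctly, but flake8's E713 prefers the dedicated operator. Standalone illustration:

config = {"params": {}}

if not "target" in config:  # E713: test for membership should be 'not in'
    print("missing")

if "target" not in config:  # equivalent, and what readers expect
    print("missing")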
@ -234,16 +234,17 @@ def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10
        .repeat_interleave(d[1], 1)
    )

    dot = lambda grad, shift: (
        torch.stack(
            (
                grid[: shape[0], : shape[1], 0] + shift[0],
                grid[: shape[0], : shape[1], 1] + shift[1],
            ),
            dim=-1,
        )
        * grad[: shape[0], : shape[1]]
    ).sum(dim=-1)
    def dot(grad, shift):
        return (
            torch.stack(
                (
                    grid[: shape[0], : shape[1], 0] + shift[0],
                    grid[: shape[0], : shape[1], 1] + shift[1],
                ),
                dim=-1,
            )
            * grad[: shape[0], : shape[1]]
        ).sum(dim=-1)

    n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]).to(device)
    n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]).to(device)
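Turning the assigned lambda into a nested `def` clears flake8's E731; a `def` also carries a real `__name__` for tracebacks and profiling. Standalone illustration:

square = lambda x: x * x  # E731: do not assign a lambda, use a def
print(square.__name__)    # '<lambda>' -- anonymous in tracebacks

def square_fn(x):
    return x * x

print(square_fn.__name__)  # 'square_fn' -- identifiable in tracebacks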
@ -287,7 +288,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
    if dest.is_dir():
        try:
            file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)
        except:
        except AttributeError:
            file_name = os.path.basename(url)
        dest = dest / file_name
    else:

@ -342,7 +343,7 @@ def url_attachment_name(url: str) -> dict:
        resp = requests.get(url, stream=True)
        match = re.search('filename="(.+)"', resp.headers.get("Content-Disposition"))
        return match.group(1)
    except:
    except Exception:
        return None


@ -1,4 +1,4 @@
"""
Initialization file for invokeai.frontend.CLI
"""
from .CLI import main as invokeai_command_line_interface
from .CLI import main as invokeai_command_line_interface  # noqa: F401

@ -1,4 +1,4 @@
"""
Wrapper for invokeai.backend.configure.invokeai_configure
"""
from ...backend.install.invokeai_configure import main as invokeai_configure
from ...backend.install.invokeai_configure import main as invokeai_configure  # noqa: F401

@ -80,7 +80,7 @@ def welcome(versions: dict):
def get_extras():
    extras = ""
    try:
        dist = pkg_resources.get_distribution("xformers")
        _ = pkg_resources.get_distribution("xformers")
        extras = "[xformers]"
    except pkg_resources.DistributionNotFound:
        pass

@ -90,7 +90,7 @@ def get_extras():
def main():
    versions = get_versions()
    if invokeai_is_running():
        print(f":exclamation: [bold red]Please terminate all running instances of InvokeAI before updating.[/red bold]")
        print(":exclamation: [bold red]Please terminate all running instances of InvokeAI before updating.[/red bold]")
        input("Press any key to continue...")
        return

@ -122,9 +122,9 @@ def main():
    print("")
    print("")
    if os.system(cmd) == 0:
        print(f":heavy_check_mark: Upgrade successful")
        print(":heavy_check_mark: Upgrade successful")
    else:
        print(f":exclamation: [bold red]Upgrade failed[/red bold]")
        print(":exclamation: [bold red]Upgrade failed[/red bold]")


if __name__ == "__main__":
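Two rules drive the `get_extras`/`main` edits above: F841 (local variable assigned but never used; fixed by binding to `_`) and F541 (f-string without any placeholders; fixed by dropping the `f` prefix). A standalone sketch using the stdlib `importlib.metadata` in place of the repo's `pkg_resources`:

import importlib.metadata

def xformers_extra() -> str:
    try:
        # F841 fix: only the success of the lookup matters, so bind to '_'.
        _ = importlib.metadata.version("xformers")
        return "[xformers]"
    except importlib.metadata.PackageNotFoundError:
        # F541 fix: no {placeholders} here, so no f-prefix is needed.
        print("xformers not installed")
        return ""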
@ -251,7 +251,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
    ) -> dict[str, npyscreen.widget]:
        """Generic code to create model selection widgets"""
        widgets = dict()
        model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and not x in exclude]
        model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude]
        model_labels = [self.model_labels[x] for x in model_list]

        show_recommended = len(self.installed_models) == 0

@ -357,14 +357,14 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
            try:
                v.hidden = True
                v.editable = False
            except:
            except Exception:
                pass
        for k, v in widgets[selected_tab].items():
            try:
                v.hidden = False
                if not isinstance(v, (npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)):
                    v.editable = True
            except:
            except Exception:
                pass
        self.__class__.current_tab = selected_tab  # for persistence
        self.display()

@ -541,7 +541,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
            self.ti_models,
        ]
        for section in ui_sections:
            if not "models_selected" in section:
            if "models_selected" not in section:
                continue
            selected = set([section["models"][x] for x in section["models_selected"].value])
            models_to_install = [x for x in selected if not self.all_models[x].installed]

@ -637,7 +637,7 @@ def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection) -> SchedulerPre
            return None
        else:
            return response
    except:
    except Exception:
        return None


@ -673,8 +673,7 @@ def process_and_execute(
def select_and_download_models(opt: Namespace):
    precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
    config.precision = precision
    helper = lambda x: ask_user_for_prediction_type(x)
    installer = ModelInstall(config, prediction_type_helper=helper)
    installer = ModelInstall(config, prediction_type_helper=ask_user_for_prediction_type)
    if opt.list_models:
        installer.list_models(opt.list_models)
    elif opt.add or opt.delete:

@ -102,8 +102,8 @@ def set_min_terminal_size(min_cols: int, min_lines: int) -> bool:
class IntSlider(npyscreen.Slider):
    def translate_value(self):
        stri = "%2d / %2d" % (self.value, self.out_of)
        l = (len(str(self.out_of))) * 2 + 4
        stri = stri.rjust(l)
        length = (len(str(self.out_of))) * 2 + 4
        stri = stri.rjust(length)
        return stri


@ -167,8 +167,8 @@ class FloatSlider(npyscreen.Slider):
    # this is supposed to adjust display precision, but doesn't
    def translate_value(self):
        stri = "%3.2f / %3.2f" % (self.value, self.out_of)
        l = (len(str(self.out_of))) * 2 + 4
        stri = stri.rjust(l)
        length = (len(str(self.out_of))) * 2 + 4
        stri = stri.rjust(length)
        return stri


@ -1,4 +1,3 @@
import os
import sys
import argparse

@ -1,4 +1,4 @@
"""
Initialization file for invokeai.frontend.merge
"""
from .merge_diffusers import main as invokeai_merge_diffusers
from .merge_diffusers import main as invokeai_merge_diffusers  # noqa: F401

@ -9,19 +9,15 @@ import curses
import sys
from argparse import Namespace
from pathlib import Path
from typing import List, Union
from typing import List, Optional

import npyscreen
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
from npyscreen import widget
from omegaconf import OmegaConf

import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import (
    ModelMerger,
    MergeInterpolationMethod,
    ModelManager,
    ModelType,
    BaseModelType,

@ -318,7 +314,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
        else:
            return True

    def get_model_names(self, base_model: BaseModelType = None) -> List[str]:
    def get_model_names(self, base_model: Optional[BaseModelType] = None) -> List[str]:
        model_names = [
            info["model_name"]
            for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
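Annotating a `None` default as `Optional[...]` states the None possibility explicitly; the old implicit-Optional reading of `base_model: BaseModelType = None` is deprecated in modern type checkers. Standalone illustration:

from typing import List, Optional

def greet(name: Optional[str] = None) -> List[str]:
    # The annotation now matches the default: callers may pass a str or nothing.
    return [f"hello {name or 'world'}"]

print(greet())
print(greet("flake8"))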
@ -1,4 +1,4 @@
"""
Initialization file for invokeai.frontend.training
"""
from .textual_inversion import main as invokeai_textual_inversion
from .textual_inversion import main as invokeai_textual_inversion  # noqa: F401

@ -59,7 +59,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):

        try:
            default = self.model_names.index(saved_args["model"])
        except:
        except Exception:
            pass

        self.add_widget_intelligent(

@ -377,7 +377,7 @@ def previous_args() -> dict:
    try:
        conf = OmegaConf.load(conf_file)
        conf["placeholder_token"] = conf["placeholder_token"].strip("<>")
    except:
    except Exception:
        conf = None

    return conf

@ -1,7 +1,7 @@
"""
initialization file for invokeai
"""
from .invokeai_version import __version__
from .invokeai_version import __version__  # noqa: F401

__app_id__ = "invoke-ai/InvokeAI"
__app_name__ = "InvokeAI"

@ -8,9 +8,8 @@ from google.colab import files
from IPython.display import Image as ipyimg
import ipywidgets as widgets
from PIL import Image
from numpy import asarray
from einops import rearrange, repeat
import torch, torchvision
import torchvision
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import ismap
import time

@ -68,14 +67,14 @@ def get_custom_cond(mode):

    elif mode == "text_conditional":
        w = widgets.Text(value="A cake with cream!", disabled=True)
        display(w)
        display(w)  # noqa: F821

        with open(f"{dest}/{mode}/custom_{w.value[:20]}.txt", "w") as f:
            f.write(w.value)

    elif mode == "class_conditional":
        w = widgets.IntSlider(min=0, max=1000)
        display(w)
        display(w)  # noqa: F821
        with open(f"{dest}/{mode}/custom.txt", "w") as f:
            f.write(w.value)

@ -96,7 +95,7 @@ def select_cond_path(mode):
    onlyfiles = [f for f in sorted(os.listdir(path))]

    selected = widgets.RadioButtons(options=onlyfiles, description="Select conditioning:", disabled=False)
    display(selected)
    display(selected)  # noqa: F821
    selected_path = os.path.join(path, selected.value)
    return selected_path

@ -123,7 +122,7 @@ def get_cond(mode, selected_path):


def visualize_cond_img(path):
    display(ipyimg(filename=path))
    display(ipyimg(filename=path))  # noqa: F821


def run(model, selected_path, task, custom_steps, resize_enabled=False, classifier_ckpt=None, global_step=None):

@ -331,7 +330,7 @@ def make_convolutional_sample(
            x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
            log["sample_noquant"] = x_sample_noquant
            log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
        except:
        except Exception:
            pass

        log["sample"] = x_sample
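In these notebook helpers, `display()` is injected by the IPython runtime rather than imported, so flake8 reports F821 (undefined name); the `# noqa: F821` comments acknowledge that. In a plain script, the explicit import avoids the warning altogether (sketch, assuming IPython is installed):

from IPython.display import display

display("hello")  # no F821: the name is defined by the import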
@ -95,7 +95,14 @@ dependencies = [
"dev" = [
  "pudb",
]
"test" = ["pytest>6.0.0", "pytest-cov", "pytest-datadir", "black"]
"test" = [
  "black",
  "flake8",
  "Flake8-pyproject",
  "pytest>6.0.0",
  "pytest-cov",
  "pytest-datadir",
]
"xformers" = [
  "xformers~=0.0.19; sys_platform!='darwin'",
  "triton; sys_platform=='linux'",

@ -185,6 +192,8 @@ output = "coverage/index.xml"

[tool.flake8]
max-line-length = 120
ignore = ["E203", "E266", "E501", "W503"]
select = ["B", "C", "E", "F", "W", "T4"]

[tool.black]
line-length = 120
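For context: stock flake8 does not read pyproject.toml; the `Flake8-pyproject` plugin added to the test extras is what makes the `[tool.flake8]` table above take effect. The ignored codes are the usual Black-compatibility set; a small, hedged illustration of what each one flags:

# E203 (ignored): whitespace before ':' -- Black sometimes formats slices this way.
items = list(range(10))
upper = items[4 :]

## E266 (ignored): block comment starting with more than one '#'.

# E501 (relaxed here via max-line-length = 120): line too long.
# W503 (ignored): line break *before* a binary operator, which Black prefers:
total = (len(items)
         + len(upper))
print(total)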
@ -4,7 +4,6 @@ Read a checkpoint/safetensors file and write out a template .json file containin
its metadata for use in fast model probing.
"""

import sys
import argparse
import json

@ -3,11 +3,12 @@

import warnings

from invokeai.app.cli_app import invoke_cli

warnings.warn(
    "dream.py is being deprecated, please run invoke.py for the " "new UI/API or legacy_api.py for the old API",
    DeprecationWarning,
)

from invokeai.app.cli_app import invoke_cli

invoke_cli()
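Moving the `invoke_cli` import below `warnings.warn(...)` lets the deprecation notice fire before the slow import work begins. Ordinarily flake8 flags any import that follows other statements as E402 (module level import not at top of file), so this layout is a deliberate trade-off. Standalone sketch:

import warnings

warnings.warn("old_tool is deprecated; use new_tool", DeprecationWarning)

# E402 would flag this late import; a targeted noqa documents that the
# early side effect above is intentional.
import json  # noqa: E402

print(json.dumps({"status": "ok"}))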
@ -2,7 +2,7 @@
"""This script reads the "Invoke" Stable Diffusion prompt embedded in files generated by invoke.py"""

import sys
from PIL import Image, PngImagePlugin
from PIL import Image

if len(sys.argv) < 2:
    print("Usage: file2prompt.py <file1.png> <file2.png> <file3.png>...")

@ -2,13 +2,11 @@

# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

import os
import logging

logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage())

import os
import sys


def main():
    # Change working directory to the repo root

@ -2,13 +2,11 @@

# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

import os
import logging

logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage())

import os
import sys


def main():
    # Change working directory to the repo root

@ -1,6 +1,7 @@
"""make variations of input image"""

import argparse, os, sys, glob
import argparse
import os
import PIL
import torch
import numpy as np

@ -12,7 +13,6 @@ from einops import rearrange, repeat
from torchvision.utils import make_grid
from torch import autocast
from contextlib import nullcontext
import time
from pytorch_lightning import seed_everything

from ldm.util import instantiate_from_config

@ -234,7 +234,6 @@ def main():
    with torch.no_grad():
        with precision_scope(device.type):
            with model.ema_scope():
                tic = time.time()
                all_samples = list()
                for n in trange(opt.n_iter, desc="Sampling"):
                    for prompts in tqdm(data, desc="data"):

@ -279,8 +278,6 @@ def main():
                Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
                grid_count += 1

                toc = time.time()

    print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.")


@ -1,4 +1,6 @@
import argparse, os, sys, glob
import argparse
import glob
import os
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm

@ -1,13 +1,13 @@
import argparse, os, sys, glob
import clip
import argparse
import glob
import os
import torch
import torch.nn as nn
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange, repeat
from einops import rearrange
from torchvision.utils import make_grid
import scann
import time

@ -390,8 +390,8 @@ if __name__ == "__main__":
            grid = make_grid(grid, nrow=n_rows)

            # to image
            grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
            Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
            grid_np = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
            Image.fromarray(grid_np.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
            grid_count += 1

    print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")
@ -1,24 +1,24 @@
import argparse, os, sys, datetime, glob, importlib, csv
import argparse
import datetime
import glob
import os
import sys

import numpy as np
import time
import torch

import torchvision
import pytorch_lightning as pl

from packaging import version
from omegaconf import OmegaConf
from torch.utils.data import random_split, DataLoader, Dataset, Subset
from torch.utils.data import DataLoader, Dataset
from functools import partial
from PIL import Image

from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    Callback,
    LearningRateMonitor,
)
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities import rank_zero_info

@ -651,7 +651,7 @@ if __name__ == "__main__":
            trainer_config["accelerator"] = "auto"
        for k in nondefault_trainer_args(opt):
            trainer_config[k] = getattr(opt, k)
        if not "gpus" in trainer_config:
        if "gpus" not in trainer_config:
            del trainer_config["accelerator"]
            cpu = True
        else:

@ -803,7 +803,7 @@ if __name__ == "__main__":
        trainer_opt.detect_anomaly = False

        trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
        trainer.logdir = logdir ###
        trainer.logdir = logdir

        # data
        config.data.params.train.params.data_root = opt.data_root

@ -2,7 +2,7 @@ from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder
from ldm.modules.embedding_manager import EmbeddingManager
from ldm.invoke.globals import Globals

import argparse, os
import argparse
from functools import partial

import torch

@ -108,7 +108,7 @@ if __name__ == "__main__":
        manager.load(manager_ckpt)

        for placeholder_string in manager.string_to_token_dict:
            if not placeholder_string in string_to_token_dict:
            if placeholder_string not in string_to_token_dict:
                string_to_token_dict[placeholder_string] = manager.string_to_token_dict[placeholder_string]
                string_to_param_dict[placeholder_string] = manager.string_to_param_dict[placeholder_string]

@ -1,6 +1,12 @@
import argparse, os, sys, glob, datetime, yaml
import torch
import argparse
import datetime
import glob
import os
import sys
import time
import yaml

import torch
import numpy as np
from tqdm import trange

@ -10,7 +16,9 @@ from PIL import Image
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config

rescale = lambda x: (x + 1.0) / 2.0

def rescale(x: float) -> float:
    return (x + 1.0) / 2.0


def custom_to_pil(x):

@ -45,7 +53,7 @@ def logs2pil(logs, keys=["sample"]):
            else:
                print(f"Unknown format for key {k}. ")
                img = None
        except:
        except Exception:
            img = None
        imgs[k] = img
    return imgs

@ -1,4 +1,5 @@
import os, sys
import os
import sys
import numpy as np
import scann
import argparse

@ -1,4 +1,5 @@
import argparse, os, sys, glob
import argparse
import os
import torch
import numpy as np
from omegaconf import OmegaConf

@ -7,10 +8,9 @@ from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
from contextlib import nullcontext

import k_diffusion as K
import torch.nn as nn

@ -251,7 +251,6 @@ def main():
    with torch.no_grad():
        with precision_scope(device.type):
            with model.ema_scope():
                tic = time.time()
                all_samples = list()
                for n in trange(opt.n_iter, desc="Sampling"):
                    for prompts in tqdm(data, desc="data"):

@ -310,8 +309,6 @@ def main():
                Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
                grid_count += 1

                toc = time.time()

    print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.")
@ -1,7 +1,6 @@
#!/bin/env python

import argparse
import sys
from pathlib import Path
from invokeai.backend.model_management.model_probe import ModelProbe

@ -90,17 +90,17 @@ def test_graph_state_executes_in_order(simple_graph, mock_services):

def test_graph_is_complete(simple_graph, mock_services):
    g = GraphExecutionState(graph=simple_graph)
    n1 = invoke_next(g, mock_services)
    n2 = invoke_next(g, mock_services)
    n3 = g.next()
    _ = invoke_next(g, mock_services)
    _ = invoke_next(g, mock_services)
    _ = g.next()

    assert g.is_complete()


def test_graph_is_not_complete(simple_graph, mock_services):
    g = GraphExecutionState(graph=simple_graph)
    n1 = invoke_next(g, mock_services)
    n2 = g.next()
    _ = invoke_next(g, mock_services)
    _ = g.next()

    assert not g.is_complete()
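Binding results the test never reads to `_` silences flake8's F841 (local variable assigned but never used) while keeping the call's side effect. Standalone sketch:

def advance() -> bool:
    return True

def test_side_effects_only():
    result = advance()  # F841: 'result' is never read afterwards
    _ = advance()       # conventional: the call matters, the value does not
    advance()           # also fine when no binding is needed
    assert True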
@ -140,11 +140,11 @@ def test_graph_state_collects(mock_services):
    graph.add_edge(create_edge("3", "prompt", "4", "item"))

    g = GraphExecutionState(graph=graph)
    n1 = invoke_next(g, mock_services)
    n2 = invoke_next(g, mock_services)
    n3 = invoke_next(g, mock_services)
    n4 = invoke_next(g, mock_services)
    n5 = invoke_next(g, mock_services)
    _ = invoke_next(g, mock_services)
    _ = invoke_next(g, mock_services)
    _ = invoke_next(g, mock_services)
    _ = invoke_next(g, mock_services)
    _ = invoke_next(g, mock_services)
    n6 = invoke_next(g, mock_services)

    assert isinstance(n6[0], CollectInvocation)

@ -195,10 +195,10 @@ def test_graph_executes_depth_first(mock_services):
    graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt"))

    g = GraphExecutionState(graph=graph)
    n1 = invoke_next(g, mock_services)
    n2 = invoke_next(g, mock_services)
    n3 = invoke_next(g, mock_services)
    n4 = invoke_next(g, mock_services)
    _ = invoke_next(g, mock_services)
    _ = invoke_next(g, mock_services)
    _ = invoke_next(g, mock_services)
    _ = invoke_next(g, mock_services)

    # Because ordering is not guaranteed, we cannot compare results directly.
    # Instead, we must count the number of results.

@ -211,17 +211,17 @@ def test_graph_executes_depth_first(mock_services):
    assert get_completed_count(g, "prompt_iterated") == 1
    assert get_completed_count(g, "prompt_successor") == 0

    n5 = invoke_next(g, mock_services)
    _ = invoke_next(g, mock_services)

    assert get_completed_count(g, "prompt_iterated") == 1
    assert get_completed_count(g, "prompt_successor") == 1

    n6 = invoke_next(g, mock_services)
    _ = invoke_next(g, mock_services)

    assert get_completed_count(g, "prompt_iterated") == 2
    assert get_completed_count(g, "prompt_successor") == 1

    n7 = invoke_next(g, mock_services)
    _ = invoke_next(g, mock_services)

    assert get_completed_count(g, "prompt_iterated") == 2
    assert get_completed_count(g, "prompt_successor") == 2

@ -17,7 +17,8 @@ from invokeai.app.services.graph import (
    IterateInvocation,
)
from invokeai.app.invocations.upscale import ESRGANInvocation
from invokeai.app.invocations.image import *

from invokeai.app.invocations.image import ShowImageInvocation
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
from invokeai.app.invocations.primitives import IntegerInvocation
from invokeai.app.services.default_graphs import create_text_to_image
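Replacing the star import with named imports clears flake8's F403 (`import *` used) and F405 (name possibly undefined, or defined from star imports), and makes each test's dependencies explicit. Standalone illustration:

from os.path import *  # noqa: F403 -- shown only to illustrate the problem

print(join("a", "b"))  # F405: is 'join' local, builtin, or from the star import?

from os.path import join as path_join  # noqa: E402 -- explicit origin, no guesswork

print(path_join("a", "b"))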
@ -41,7 +42,7 @@ def test_connections_are_compatible():

    result = are_connections_compatible(from_node, from_field, to_node, to_field)

    assert result == True
    assert result is True


def test_connections_are_incompatible():

@ -52,7 +53,7 @@ def test_connections_are_incompatible():

    result = are_connections_compatible(from_node, from_field, to_node, to_field)

    assert result == False
    assert result is False


def test_connections_incompatible_with_invalid_fields():

@ -63,14 +64,14 @@ def test_connections_incompatible_with_invalid_fields():

    # From field is invalid
    result = are_connections_compatible(from_node, from_field, to_node, to_field)
    assert result == False
    assert result is False

    # To field is invalid
    from_field = "image"
    to_field = "invalid_field"

    result = are_connections_compatible(from_node, from_field, to_node, to_field)
    assert result == False
    assert result is False


def test_graph_can_add_node():
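`assert result == True` works but trips flake8's E712; comparison against the `True`/`False` singletons should use `is`, or better, the bare truth test. Standalone illustration:

result = (2 + 2 == 4)

assert result == True  # E712: comparison to True should be 'is True' or implicit
assert result is True  # identity against the singleton
assert result          # simplest and idiomatic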
@ -394,7 +395,7 @@ def test_graph_validates():
    e1 = create_edge("1", "image", "2", "image")
    g.add_edge(e1)

    assert g.is_valid() == True
    assert g.is_valid() is True


def test_graph_invalid_if_edges_reference_missing_nodes():

@ -404,7 +405,7 @@ def test_graph_invalid_if_edges_reference_missing_nodes():
    e1 = create_edge("1", "image", "2", "image")
    g.edges.append(e1)

    assert g.is_valid() == False
    assert g.is_valid() is False


def test_graph_invalid_if_subgraph_invalid():

@ -419,7 +420,7 @@ def test_graph_invalid_if_subgraph_invalid():

    g.nodes[n1.id] = n1

    assert g.is_valid() == False
    assert g.is_valid() is False


def test_graph_invalid_if_has_cycle():

@ -433,7 +434,7 @@ def test_graph_invalid_if_has_cycle():
    g.edges.append(e1)
    g.edges.append(e2)

    assert g.is_valid() == False
    assert g.is_valid() is False


def test_graph_invalid_with_invalid_connection():

@ -445,7 +446,7 @@ def test_graph_invalid_with_invalid_connection():
    e1 = create_edge("1", "image", "2", "strength")
    g.edges.append(e1)

    assert g.is_valid() == False
    assert g.is_valid() is False


# TODO: Subgraph operations

@ -536,7 +537,7 @@ def test_graph_fails_to_get_missing_subgraph_node():
    g.add_node(n1)

    with pytest.raises(NodeNotFoundError):
        result = g.get_node("1.2")
        _ = g.get_node("1.2")


def test_graph_fails_to_enumerate_non_subgraph_node():

@ -554,7 +555,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
    g.add_node(n2)

    with pytest.raises(NodeNotFoundError):
        result = g.get_node("2.1")
        _ = g.get_node("2.1")


def test_graph_gets_networkx_graph():

@ -584,7 +585,7 @@ def test_graph_can_serialize():
    g.add_edge(e)

    # Not throwing on this line is sufficient
    json = g.json()
    _ = g.json()


def test_graph_can_deserialize():

@ -612,4 +613,4 @@ def test_graph_can_deserialize():
def test_graph_can_generate_schema():
    # Not throwing on this line is sufficient
    # NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
    schema = Graph.schema_json(indent=2)
    _ = Graph.schema_json(indent=2)

@ -1,9 +1,10 @@
from typing import Any, Callable, Literal, Union

from pydantic import Field
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from invokeai.app.invocations.image import ImageField
from invokeai.app.services.invocation_services import InvocationServices
from pydantic import Field
import pytest
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import Edge, EdgeConnection


# Define test invocations before importing anything that uses invocations

@ -82,10 +83,6 @@ class PromptCollectionTestInvocation(BaseInvocation):
        return PromptCollectionTestInvocationOutput(collection=self.collection.copy())


from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import Edge, EdgeConnection


def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:
    return Edge(
        source=EdgeConnection(node_id=from_id, field=from_field),
@ -1,13 +1,16 @@
import os
import pytest
import sys
from typing import Any

import pytest
from omegaconf import OmegaConf
from pathlib import Path

os.environ["INVOKEAI_ROOT"] = "/tmp"

from invokeai.app.services.config import InvokeAIAppConfig
@pytest.fixture
def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
    """This may be overkill since the current tests don't need the root dir to exist"""
    monkeypatch.setenv("INVOKEAI_ROOT", str(tmp_path))


init1 = OmegaConf.create(
    """

@ -32,10 +35,12 @@ InvokeAI:
)


def test_use_init():
def test_use_init(patch_rootdir):
    # note that we explicitly set omegaconf dict and argv here
    # so that the values aren't read from ~invokeai/invokeai.yaml and
    # sys.argv respectively.
    from invokeai.app.services.config import InvokeAIAppConfig

    conf1 = InvokeAIAppConfig.get_config()
    assert conf1
    conf1.parse_args(conf=init1, argv=[])
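The new `patch_rootdir` fixture swaps the hard-coded `/tmp` root for a per-test temporary directory using pytest's built-in `tmp_path` and `monkeypatch` fixtures; `monkeypatch.setenv` is undone automatically after each test. A standalone sketch of the pattern (the APP_HOME variable is illustrative):

import os
from pathlib import Path

import pytest

@pytest.fixture
def fake_home(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path:
    # Reverted automatically at test teardown, so state cannot leak.
    monkeypatch.setenv("APP_HOME", str(tmp_path))
    return tmp_path

def test_reads_app_home(fake_home: Path):
    assert os.environ["APP_HOME"] == str(fake_home)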
@ -51,7 +56,9 @@ def test_use_init():
    assert not hasattr(conf2, "invalid_attribute")


def test_argv_override():
def test_argv_override(patch_rootdir):
    from invokeai.app.services.config import InvokeAIAppConfig

    conf = InvokeAIAppConfig.get_config()
    conf.parse_args(conf=init1, argv=["--always_use_cpu", "--max_cache=10"])
    assert conf.always_use_cpu

@ -59,14 +66,16 @@ def test_argv_override():
    assert conf.outdir == Path("outputs")  # this is the default


def test_env_override():
def test_env_override(patch_rootdir):
    from invokeai.app.services.config import InvokeAIAppConfig

    # argv overrides
    conf = InvokeAIAppConfig()
    conf.parse_args(conf=init1, argv=["--max_cache=10"])
    assert conf.always_use_cpu == False
    assert conf.always_use_cpu is False
    os.environ["INVOKEAI_always_use_cpu"] = "True"
    conf.parse_args(conf=init1, argv=["--max_cache=10"])
    assert conf.always_use_cpu == True
    assert conf.always_use_cpu is True

    # environment variables should be case insensitive
    os.environ["InvokeAI_Max_Cache_Size"] = "15"

@ -76,7 +85,7 @@ def test_env_override():

    conf = InvokeAIAppConfig()
    conf.parse_args(conf=init1, argv=["--no-always_use_cpu", "--max_cache=10"])
    assert conf.always_use_cpu == False
    assert conf.always_use_cpu is False
    assert conf.max_cache_size == 10

    conf = InvokeAIAppConfig.get_config(max_cache_size=20)

@ -84,7 +93,9 @@ def test_env_override():
    assert conf.max_cache_size == 20


def test_root_resists_cwd():
def test_root_resists_cwd(patch_rootdir):
    from invokeai.app.services.config import InvokeAIAppConfig

    previous = os.environ["INVOKEAI_ROOT"]
    cwd = Path(os.getcwd()).resolve()

@ -99,7 +110,9 @@ def test_root_resists_cwd():
    os.chdir(cwd)


def test_type_coercion():
def test_type_coercion(patch_rootdir):
    from invokeai.app.services.config import InvokeAIAppConfig

    conf = InvokeAIAppConfig().get_config()
    conf.parse_args(argv=["--root=/tmp/foobar"])
    assert conf.root == Path("/tmp/foobar")