Merge branch 'main' into fix/inpaint_gen

This commit is contained in:
psychedelicious 2023-08-18 15:57:48 +10:00
commit 3c43594c26
106 changed files with 444 additions and 417 deletions

View File

@ -1,6 +1,6 @@
name: style checks
# just formatting for now
# TODO: add isort and flake8 later
# just formatting and flake8 for now
# TODO: add isort later
on:
pull_request:
@ -20,8 +20,8 @@ jobs:
- name: Install dependencies with pip
run: |
pip install black
pip install black flake8 Flake8-pyproject
# - run: isort --check-only .
- run: black --check .
# - run: flake8
- run: flake8

View File

@ -8,3 +8,10 @@ repos:
language: system
entry: black
types: [python]
- id: flake8
name: flake8
stages: [commit]
language: system
entry: flake8
types: [python]

View File

@ -407,7 +407,7 @@ def get_pip_from_venv(venv_path: Path) -> str:
:rtype: str
"""
pip = "Scripts\pip.exe" if OS == "Windows" else "bin/pip"
pip = "Scripts\\pip.exe" if OS == "Windows" else "bin/pip"
return str(venv_path.expanduser().resolve() / pip)

View File

@ -49,7 +49,7 @@ if __name__ == "__main__":
try:
inst.install(**args.__dict__)
except KeyboardInterrupt as exc:
except KeyboardInterrupt:
print("\n")
print("Ctrl-C pressed. Aborting.")
print("Come back soon!")

View File

@ -70,7 +70,7 @@ def confirm_install(dest: Path) -> bool:
)
else:
print(f"InvokeAI will be installed in {dest}")
dest_confirmed = not Confirm.ask(f"Would you like to pick a different location?", default=False)
dest_confirmed = not Confirm.ask("Would you like to pick a different location?", default=False)
console.line()
return dest_confirmed
@ -90,7 +90,7 @@ def dest_path(dest=None) -> Path:
dest = Path(dest).expanduser().resolve()
else:
dest = Path.cwd().expanduser().resolve()
prev_dest = dest.expanduser().resolve()
prev_dest = init_path = dest
dest_confirmed = confirm_install(dest)
@ -109,9 +109,9 @@ def dest_path(dest=None) -> Path:
)
console.line()
print(f"[orange3]Please select the destination directory for the installation:[/] \[{browse_start}]: ")
console.print(f"[orange3]Please select the destination directory for the installation:[/] \\[{browse_start}]: ")
selected = prompt(
f">>> ",
">>> ",
complete_in_thread=True,
completer=path_completer,
default=str(browse_start) + os.sep,
@ -134,14 +134,14 @@ def dest_path(dest=None) -> Path:
try:
dest.mkdir(exist_ok=True, parents=True)
return dest
except PermissionError as exc:
print(
except PermissionError:
console.print(
f"Failed to create directory {dest} due to insufficient permissions",
style=Style(color="red"),
highlight=True,
)
except OSError as exc:
console.print_exception(exc)
except OSError:
console.print_exception()
if Confirm.ask("Would you like to try again?"):
dest_path(init_path)

View File

@ -1,6 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Optional
from logging import Logger
from invokeai.app.services.board_image_record_storage import (
SqliteBoardImageRecordStorage,
@ -45,7 +44,7 @@ def check_internet() -> bool:
try:
urllib.request.urlopen(host, timeout=1)
return True
except:
except Exception:
return False

View File

@ -34,7 +34,7 @@ async def add_image_to_board(
board_id=board_id, image_name=image_name
)
return result
except Exception as e:
except Exception:
raise HTTPException(status_code=500, detail="Failed to add image to board")
@ -53,7 +53,7 @@ async def remove_image_from_board(
try:
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
return result
except Exception as e:
except Exception:
raise HTTPException(status_code=500, detail="Failed to remove image from board")
@ -79,10 +79,10 @@ async def add_images_to_board(
board_id=board_id, image_name=image_name
)
added_image_names.append(image_name)
except:
except Exception:
pass
return AddImagesToBoardResult(board_id=board_id, added_image_names=added_image_names)
except Exception as e:
except Exception:
raise HTTPException(status_code=500, detail="Failed to add images to board")
@ -105,8 +105,8 @@ async def remove_images_from_board(
try:
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
removed_image_names.append(image_name)
except:
except Exception:
pass
return RemoveImagesFromBoardResult(removed_image_names=removed_image_names)
except Exception as e:
except Exception:
raise HTTPException(status_code=500, detail="Failed to remove images from board")

View File

@ -37,7 +37,7 @@ async def create_board(
try:
result = ApiDependencies.invoker.services.boards.create(board_name=board_name)
return result
except Exception as e:
except Exception:
raise HTTPException(status_code=500, detail="Failed to create board")
@ -50,7 +50,7 @@ async def get_board(
try:
result = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
return result
except Exception as e:
except Exception:
raise HTTPException(status_code=404, detail="Board not found")
@ -73,7 +73,7 @@ async def update_board(
try:
result = ApiDependencies.invoker.services.boards.update(board_id=board_id, changes=changes)
return result
except Exception as e:
except Exception:
raise HTTPException(status_code=500, detail="Failed to update board")
@ -105,7 +105,7 @@ async def delete_board(
deleted_board_images=deleted_board_images,
deleted_images=[],
)
except Exception as e:
except Exception:
raise HTTPException(status_code=500, detail="Failed to delete board")

View File

@ -55,7 +55,7 @@ async def upload_image(
if crop_visible:
bbox = pil_image.getbbox()
pil_image = pil_image.crop(bbox)
except:
except Exception:
# Error opening the image
raise HTTPException(status_code=415, detail="Failed to read image")
@ -73,7 +73,7 @@ async def upload_image(
response.headers["Location"] = image_dto.image_url
return image_dto
except Exception as e:
except Exception:
raise HTTPException(status_code=500, detail="Failed to create image")
@ -85,7 +85,7 @@ async def delete_image(
try:
ApiDependencies.invoker.services.images.delete(image_name)
except Exception as e:
except Exception:
# TODO: Does this need any exception handling at all?
pass
@ -97,7 +97,7 @@ async def clear_intermediates() -> int:
try:
count_deleted = ApiDependencies.invoker.services.images.delete_intermediates()
return count_deleted
except Exception as e:
except Exception:
raise HTTPException(status_code=500, detail="Failed to clear intermediates")
pass
@ -115,7 +115,7 @@ async def update_image(
try:
return ApiDependencies.invoker.services.images.update(image_name, image_changes)
except Exception as e:
except Exception:
raise HTTPException(status_code=400, detail="Failed to update image")
@ -131,7 +131,7 @@ async def get_image_dto(
try:
return ApiDependencies.invoker.services.images.get_dto(image_name)
except Exception as e:
except Exception:
raise HTTPException(status_code=404)
@ -147,7 +147,7 @@ async def get_image_metadata(
try:
return ApiDependencies.invoker.services.images.get_metadata(image_name)
except Exception as e:
except Exception:
raise HTTPException(status_code=404)
@ -183,7 +183,7 @@ async def get_image_full(
)
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
return response
except Exception as e:
except Exception:
raise HTTPException(status_code=404)
@ -212,7 +212,7 @@ async def get_image_thumbnail(
response = FileResponse(path, media_type="image/webp", content_disposition_type="inline")
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
return response
except Exception as e:
except Exception:
raise HTTPException(status_code=404)
@ -234,7 +234,7 @@ async def get_image_urls(
image_url=image_url,
thumbnail_url=thumbnail_url,
)
except Exception as e:
except Exception:
raise HTTPException(status_code=404)
@ -282,10 +282,10 @@ async def delete_images_from_list(
try:
ApiDependencies.invoker.services.images.delete(image_name)
deleted_images.append(image_name)
except:
except Exception:
pass
return DeleteImagesFromListResult(deleted_images=deleted_images)
except Exception as e:
except Exception:
raise HTTPException(status_code=500, detail="Failed to delete images")
@ -303,10 +303,10 @@ async def star_images_in_list(
try:
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=True))
updated_image_names.append(image_name)
except:
except Exception:
pass
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
except Exception as e:
except Exception:
raise HTTPException(status_code=500, detail="Failed to star images")
@ -320,8 +320,8 @@ async def unstar_images_in_list(
try:
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=False))
updated_image_names.append(image_name)
except:
except Exception:
pass
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
except Exception as e:
except Exception:
raise HTTPException(status_code=500, detail="Failed to unstar images")

View File

@ -1,12 +1,13 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Annotated, List, Optional, Union
from typing import Annotated, Optional, Union
from fastapi import Body, HTTPException, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic.fields import Field
from ...invocations import *
# Importing * is bad karma but needed here for node detection
from ...invocations import * # noqa: F401 F403
from ...invocations.baseinvocation import BaseInvocation
from ...services.graph import (
Edge,

View File

@ -1,6 +1,5 @@
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import asyncio
import sys
from inspect import signature
import logging
@ -17,21 +16,11 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware
from pathlib import Path
from pydantic.schema import schema
# This should come early so that modules can log their initialization properly
from .services.config import InvokeAIAppConfig
from ..backend.util.logging import InvokeAILogger
app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
logger = InvokeAILogger.getLogger(config=app_config)
from invokeai.version.invokeai_version import __version__
# we call this early so that the message appears before
# other invokeai initialization messages
if app_config.version:
print(f"InvokeAI version {__version__}")
sys.exit(0)
import invokeai.frontend.web as web_dir
import mimetypes
@ -40,12 +29,17 @@ from .api.routers import sessions, models, images, boards, board_images, app_inf
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
import torch
import invokeai.backend.util.hotfixes
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
logger = InvokeAILogger.getLogger(config=app_config)
# fix for windows mimetypes registry entries being borked
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
@ -230,13 +224,16 @@ def invoke_api():
# replace uvicorn's loggers with InvokeAI's for consistent appearance
for logname in ["uvicorn.access", "uvicorn"]:
l = logging.getLogger(logname)
l.handlers.clear()
log = logging.getLogger(logname)
log.handlers.clear()
for ch in logger.handlers:
l.addHandler(ch)
log.addHandler(ch)
loop.run_until_complete(server.serve())
if __name__ == "__main__":
if app_config.version:
print(f"InvokeAI version {__version__}")
else:
invoke_api()

View File

@ -145,10 +145,10 @@ def set_autocompleter(services: InvocationServices) -> Completer:
completer = Completer(services.model_manager)
readline.set_completer(completer.complete)
# pyreadline3 does not have a set_auto_history() method
try:
readline.set_auto_history(True)
except:
except AttributeError:
# pyreadline3 does not have a set_auto_history() method
pass
readline.set_pre_input_hook(completer._pre_input_hook)
readline.set_completer_delims(" ")

View File

@ -13,16 +13,8 @@ from pydantic.fields import Field
# This should come early so that the logger can pick up its configuration options
from .services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger().getLogger(config=config)
from invokeai.version.invokeai_version import __version__
# we call this early so that the message appears before other invokeai initialization messages
if config.version:
print(f"InvokeAI version {__version__}")
sys.exit(0)
from invokeai.app.services.board_image_record_storage import (
SqliteBoardImageRecordStorage,
@ -62,10 +54,15 @@ from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage
import torch
import invokeai.backend.util.hotfixes
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger().getLogger(config=config)
class CliCommand(BaseModel):
@ -482,4 +479,7 @@ def invoke_cli():
if __name__ == "__main__":
if config.version:
print(f"InvokeAI version {__version__}")
else:
invoke_cli()

View File

@ -5,10 +5,10 @@ from typing import Literal
import numpy as np
from pydantic import validator
from invokeai.app.invocations.primitives import ImageCollectionOutput, ImageField, IntegerCollectionOutput
from invokeai.app.invocations.primitives import IntegerCollectionOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIType, tags, title
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
@title("Integer Range")

View File

@ -12,7 +12,7 @@ from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion impor
SDXLConditioningInfo,
)
from ...backend.model_management import ModelPatcher, ModelType
from ...backend.model_management.models import ModelType
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import ModelNotFoundException
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent

View File

@ -29,7 +29,7 @@ from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from ...backend.model_management import BaseModelType, ModelType
from ...backend.model_management import BaseModelType
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import (
BaseInvocation,

View File

@ -90,7 +90,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
return im
# Find all invalid tiles and replace with a random valid tile
replace_count = (tiles_mask == False).sum()
replace_count = (tiles_mask is False).sum()
rng = np.random.default_rng(seed=seed)
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]

View File

@ -15,7 +15,7 @@ from diffusers.models.attention_processor import (
)
from diffusers.schedulers import DPMSolverSDEScheduler
from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator
from pydantic import validator
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.metadata import CoreMetadata
@ -32,7 +32,7 @@ from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from ...backend.model_management import BaseModelType, ModelPatcher
from ...backend.model_management.models import BaseModelType
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
@ -47,12 +47,10 @@ from ...backend.util.devices import choose_precision, choose_torch_device
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
OutputField,
UIType,
tags,
title,

View File

@ -1,5 +1,5 @@
import copy
from typing import List, Literal, Optional, Union
from typing import List, Literal, Optional
from pydantic import BaseModel, Field

View File

@ -16,7 +16,6 @@ from .baseinvocation import (
InputField,
InvocationContext,
OutputField,
UIType,
tags,
title,
)

View File

@ -2,14 +2,13 @@
import inspect
import re
from contextlib import ExitStack
# from contextlib import ExitStack
from typing import List, Literal, Optional, Union
import numpy as np
import torch
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator
from tqdm import tqdm
@ -72,7 +71,7 @@ class ONNXPromptInvocation(BaseInvocation):
text_encoder_info = context.services.model_manager.get_model(
**self.clip.text_encoder.dict(),
)
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack:
loras = [
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
for lora in self.clip.loras
@ -259,7 +258,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
with unet_info as unet, ExitStack() as stack:
with unet_info as unet: # , ExitStack() as stack:
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
loras = [
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)

View File

@ -38,13 +38,10 @@ from easing_functions import (
SineEaseInOut,
SineEaseOut,
)
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator
from pydantic import BaseModel, Field
from invokeai.app.invocations.primitives import FloatCollectionOutput
from ...backend.util.logging import InvokeAILogger
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title

View File

@ -1,9 +1,8 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal, Optional, Tuple, Union
from typing import Literal, Optional, Tuple
import torch
from anyio import Condition
from pydantic import BaseModel, Field
from .baseinvocation import (

View File

@ -1,6 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
from pathlib import Path
from typing import Literal, Union
from typing import Literal
import cv2 as cv
import numpy as np

View File

@ -1,18 +1,14 @@
from abc import ABC, abstractmethod
from logging import Logger
from typing import List, Union, Optional
from typing import Optional
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_record_storage import (
BoardRecord,
BoardRecordStorageBase,
)
from invokeai.app.services.image_record_storage import (
ImageRecordStorageBase,
OffsetPaginatedResults,
)
from invokeai.app.services.image_record_storage import ImageRecordStorageBase
from invokeai.app.services.models.board_record import BoardDTO
from invokeai.app.services.models.image_record import ImageDTO, image_record_to_dto
from invokeai.app.services.urls import UrlServiceBase

View File

@ -1,15 +1,14 @@
from abc import ABC, abstractmethod
from typing import Optional, cast
import sqlite3
import threading
from typing import Optional, Union
import uuid
from abc import ABC, abstractmethod
from typing import Optional, Union, cast
import sqlite3
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.board_record import (
BoardRecord,
deserialize_board_record,
)
from pydantic import BaseModel, Field, Extra
@ -230,7 +229,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
# Change the name of a board
if changes.board_name is not None:
self._cursor.execute(
f"""--sql
"""--sql
UPDATE boards
SET board_name = ?
WHERE board_id = ?;
@ -241,7 +240,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
# Change the cover image of a board
if changes.cover_image_name is not None:
self._cursor.execute(
f"""--sql
"""--sql
UPDATE boards
SET cover_image_name = ?
WHERE board_id = ?;

View File

@ -167,7 +167,7 @@ from argparse import ArgumentParser
from omegaconf import OmegaConf, DictConfig, ListConfig
from pathlib import Path
from pydantic import BaseSettings, Field, parse_obj_as
from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args
from typing import ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args
INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
@ -394,7 +394,7 @@ class InvokeAIAppConfig(InvokeAISettings):
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='auto',description='Floating point precision', category='Memory/Performance')
precision : Literal['auto', 'float16', 'float32', 'autocast'] = Field(default='auto', description='Floating point precision', category='Memory/Performance')
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
@ -415,8 +415,8 @@ class InvokeAIAppConfig(InvokeAISettings):
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
log_format : Literal[tuple(['plain','color','syslog','legacy'])] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
# fmt: on
@ -438,7 +438,7 @@ class InvokeAIAppConfig(InvokeAISettings):
if conf is None:
try:
conf = OmegaConf.load(self.root_dir / INIT_FILE)
except:
except Exception:
pass
InvokeAISettings.initconf = conf
@ -457,7 +457,7 @@ class InvokeAIAppConfig(InvokeAISettings):
"""
if (
cls.singleton_config is None
or type(cls.singleton_config) != cls
or type(cls.singleton_config) is not cls
or (kwargs and cls.singleton_init != kwargs)
):
cls.singleton_config = cls(**kwargs)

View File

@ -9,7 +9,8 @@ import networkx as nx
from pydantic import BaseModel, root_validator, validator
from pydantic.fields import Field
from ..invocations import *
# Importing * is bad karma but needed here for node detection
from ..invocations import * # noqa: F401 F403
from ..invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -445,7 +446,7 @@ class Graph(BaseModel):
node = graph.nodes[node_id]
# Ensure the node type matches the new node
if type(node) != type(new_node):
if type(node) is not type(new_node):
raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}")
# Ensure the new id is either the same or is not in the graph
@ -632,7 +633,7 @@ class Graph(BaseModel):
[
t
for input_field in input_fields
for t in ([input_field] if get_origin(input_field) == None else get_args(input_field))
for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
if t != NoneType
]
) # Get unique types
@ -923,7 +924,7 @@ class GraphExecutionState(BaseModel):
None,
)
if next_node_id == None:
if next_node_id is None:
return None
# Get all parents of the next node

View File

@ -179,7 +179,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: Path, image: PILImageType):
if not image_name in self.__cache:
if image_name not in self.__cache:
self.__cache[image_name] = image
self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size:

View File

@ -282,7 +282,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self._lock.acquire()
self._cursor.execute(
f"""--sql
"""--sql
SELECT images.metadata FROM images
WHERE image_name = ?;
""",
@ -309,7 +309,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Change the category of the image
if changes.image_category is not None:
self._cursor.execute(
f"""--sql
"""--sql
UPDATE images
SET image_category = ?
WHERE image_name = ?;
@ -320,7 +320,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Change the session associated with the image
if changes.session_id is not None:
self._cursor.execute(
f"""--sql
"""--sql
UPDATE images
SET session_id = ?
WHERE image_name = ?;
@ -331,7 +331,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Change the image's `is_intermediate`` flag
if changes.is_intermediate is not None:
self._cursor.execute(
f"""--sql
"""--sql
UPDATE images
SET is_intermediate = ?
WHERE image_name = ?;
@ -342,7 +342,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Change the image's `starred`` state
if changes.starred is not None:
self._cursor.execute(
f"""--sql
"""--sql
UPDATE images
SET starred = ?
WHERE image_name = ?;

View File

@ -1,4 +1,3 @@
import json
from abc import ABC, abstractmethod
from logging import Logger
from typing import TYPE_CHECKING, Optional
@ -379,10 +378,10 @@ class ImageService(ImageServiceABC):
self._services.image_files.delete(image_name)
self._services.image_records.delete(image_name)
except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record")
self._services.logger.error("Failed to delete image record")
raise
except ImageFileDeleteException:
self._services.logger.error(f"Failed to delete image file")
self._services.logger.error("Failed to delete image file")
raise
except Exception as e:
self._services.logger.error("Problem deleting image record and file")
@ -395,10 +394,10 @@ class ImageService(ImageServiceABC):
self._services.image_files.delete(image_name)
self._services.image_records.delete_many(image_names)
except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image records")
self._services.logger.error("Failed to delete image records")
raise
except ImageFileDeleteException:
self._services.logger.error(f"Failed to delete image files")
self._services.logger.error("Failed to delete image files")
raise
except Exception as e:
self._services.logger.error("Problem deleting image records and files")
@ -412,10 +411,10 @@ class ImageService(ImageServiceABC):
self._services.image_files.delete(image_name)
return count
except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image records")
self._services.logger.error("Failed to delete image records")
raise
except ImageFileDeleteException:
self._services.logger.error(f"Failed to delete image files")
self._services.logger.error("Failed to delete image files")
raise
except Exception as e:
self._services.logger.error("Problem deleting image records and files")

View File

@ -7,6 +7,7 @@ if TYPE_CHECKING:
from invokeai.app.services.board_images import BoardImagesServiceABC
from invokeai.app.services.boards import BoardServiceABC
from invokeai.app.services.images import ImageServiceABC
from invokeai.app.services.invocation_stats import InvocationStatsServiceBase
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.latent_storage import LatentsStorageBase

View File

@ -1,7 +1,6 @@
# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
"""Utility to collect execution time and GPU usage stats on invocations in flight"""
"""Utility to collect execution time and GPU usage stats on invocations in flight
"""
Usage:
statistics = InvocationStatsService(graph_execution_manager)

View File

@ -60,7 +60,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
return None if name not in self.__cache else self.__cache[name]
def __set_cache(self, name: str, data: torch.Tensor):
if not name in self.__cache:
if name not in self.__cache:
self.__cache[name] = data
self.__cache_ids.put(name)
if self.__cache_ids.qsize() > self.__max_cache_size:

View File

@ -1,3 +1,4 @@
from typing import Union
import torch
import numpy as np
import cv2
@ -5,7 +6,7 @@ from PIL import Image
from diffusers.utils import PIL_INTERPOLATION
from einops import rearrange
from controlnet_aux.util import HWC3, resize_image
from controlnet_aux.util import HWC3
###################################################################
# Copy of scripts/lvminthin.py from Mikubill/sd-webui-controlnet
@ -232,7 +233,8 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:
k0 = float(h) / old_h
k1 = float(w) / old_w
safeint = lambda x: int(np.round(x))
def safeint(x: Union[int, float]) -> int:
return int(np.round(x))
# if resize_mode == external_code.ResizeMode.OUTER_FIT:
if resize_mode == "fill_resize": # OUTER_FIT

View File

@ -5,7 +5,6 @@ from invokeai.app.models.image import ProgressImage
from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
from ...backend.stable_diffusion import PipelineIntermediateState
from invokeai.app.services.config import InvokeAIAppConfig
from ...backend.model_management.models import BaseModelType

View File

@ -1,5 +1,5 @@
"""
Initialization file for invokeai.backend
"""
from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubModelType, ModelInfo
from .model_management.models import SilenceWarnings
from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubModelType, ModelInfo # noqa: F401
from .model_management.models import SilenceWarnings # noqa: F401

View File

@ -1,14 +1,16 @@
"""
Initialization file for invokeai.backend.image_util methods.
"""
from .patchmatch import PatchMatch
from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata
from .seamless import configure_model_padding
from .txt2mask import Txt2Mask
from .util import InitImageResizer, make_grid
from .patchmatch import PatchMatch # noqa: F401
from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata # noqa: F401
from .seamless import configure_model_padding # noqa: F401
from .txt2mask import Txt2Mask # noqa: F401
from .util import InitImageResizer, make_grid # noqa: F401
def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False):
from PIL import ImageDraw
if not debug_status:
return

View File

@ -26,7 +26,7 @@ class PngWriter:
dirlist = sorted(os.listdir(self.outdir), reverse=True)
# find the first filename that matches our pattern or return 000000.0.png
existing_name = next(
(f for f in dirlist if re.match("^(\d+)\..*\.png", f)),
(f for f in dirlist if re.match(r"^(\d+)\..*\.png", f)),
"0000000.0.png",
)
basecount = int(existing_name.split(".", 1)[0]) + 1
@ -98,11 +98,11 @@ class PromptFormatter:
# to do: put model name into the t2i object
# switches.append(f'--model{t2i.model_name}')
if opt.seamless or t2i.seamless:
switches.append(f"--seamless")
switches.append("--seamless")
if opt.init_img:
switches.append(f"-I{opt.init_img}")
if opt.fit:
switches.append(f"--fit")
switches.append("--fit")
if opt.strength and opt.init_img is not None:
switches.append(f"-f{opt.strength or t2i.strength}")
if opt.gfpgan_strength:

View File

@ -52,7 +52,6 @@ from invokeai.frontend.install.widgets import (
SingleSelectColumns,
CenteredButtonPress,
FileBox,
IntTitleSlider,
set_min_terminal_size,
CyclingForm,
MIN_COLS,

View File

@ -116,7 +116,7 @@ class MigrateTo3(object):
appropriate location within the destination models directory.
"""
directories_scanned = set()
for root, dirs, files in os.walk(src_dir):
for root, dirs, files in os.walk(src_dir, followlinks=True):
for d in dirs:
try:
model = Path(root, d)
@ -525,7 +525,7 @@ def do_migrate(src_directory: Path, dest_directory: Path):
if version_3: # write into the dest directory
try:
shutil.copy(dest_directory / "configs" / "models.yaml", config_file)
except:
except Exception:
MigrateTo3.initialize_yaml(config_file)
mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory
(dest_directory / "models").replace(dest_models)

View File

@ -12,7 +12,6 @@ from typing import Optional, List, Dict, Callable, Union, Set
import requests
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
import onnx
import torch
from huggingface_hub import hf_hub_url, HfFolder, HfApi
from omegaconf import OmegaConf

View File

@ -1,10 +1,10 @@
"""
Initialization file for invokeai.backend.model_management
"""
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
from .model_cache import ModelCache
from .lora import ModelPatcher, ONNXModelPatcher
from .models import (
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType # noqa: F401
from .model_cache import ModelCache # noqa: F401
from .lora import ModelPatcher, ONNXModelPatcher # noqa: F401
from .models import ( # noqa: F401
BaseModelType,
ModelType,
SubModelType,
@ -12,5 +12,4 @@ from .models import (
ModelNotFoundException,
DuplicateModelException,
)
from .model_merge import ModelMerger, MergeInterpolationMethod
from .lora import ModelPatcher
from .model_merge import ModelMerger, MergeInterpolationMethod # noqa: F401

View File

@ -5,21 +5,16 @@ from contextlib import contextmanager
from typing import Optional, Dict, Tuple, Any, Union, List
from pathlib import Path
import torch
from safetensors.torch import load_file
from torch.utils.hooks import RemovableHandle
from diffusers.models import UNet2DConditionModel
from transformers import CLIPTextModel
from onnx import numpy_helper
from onnxruntime import OrtValue
import numpy as np
import torch
from compel.embeddings_provider import BaseTextualInversionManager
from diffusers.models import UNet2DConditionModel
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer
from .models.lora import LoRAModel
"""
loras = [
(lora_model1, 0.7),
@ -52,7 +47,7 @@ class ModelPatcher:
module = module.get_submodule(submodule_name)
module_key += "." + submodule_name
submodule_name = key_parts.pop(0)
except:
except Exception:
submodule_name += "_" + key_parts.pop(0)
module = module.get_submodule(submodule_name)
@ -312,7 +307,8 @@ class TextualInversionManager(BaseTextualInversionManager):
class ONNXModelPatcher:
from .models.base import IAIOnnxRuntimeModel, OnnxRuntimeModel
from .models.base import IAIOnnxRuntimeModel
from diffusers import OnnxRuntimeModel
@classmethod
@contextmanager
@ -341,7 +337,7 @@ class ONNXModelPatcher:
def apply_lora(
cls,
model: IAIOnnxRuntimeModel,
loras: List[Tuple[LoraModel, float]],
loras: List[Tuple[LoRAModel, float]],
prefix: str,
):
from .models.base import IAIOnnxRuntimeModel

View File

@ -273,7 +273,7 @@ class ModelCache(object):
self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
self.cache._print_cuda_stats()
except:
except Exception:
self.cache_entry.unlock()
raise

View File

@ -419,12 +419,12 @@ class ModelManager(object):
base_model_str, model_type_str, model_name = model_key.split("/", 2)
try:
model_type = ModelType(model_type_str)
except:
except Exception:
raise Exception(f"Unknown model type: {model_type_str}")
try:
base_model = BaseModelType(base_model_str)
except:
except Exception:
raise Exception(f"Unknown base model: {base_model_str}")
return (model_name, base_model, model_type)
@ -855,7 +855,7 @@ class ModelManager(object):
info.pop("config")
result = self.add_model(model_name, base_model, model_type, model_attributes=info, clobber=True)
except:
except Exception:
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
rmtree(new_diffusers_path)
raise
@ -1042,7 +1042,7 @@ class ModelManager(object):
# Patch in the SD VAE from core so that it is available for use by the UI
try:
self.heuristic_import({str(self.resolve_model_path("core/convert/sd-vae-ft-mse"))})
except:
except Exception:
pass
installer = ModelInstall(

View File

@ -431,7 +431,7 @@ class PipelineFolderProbe(FolderProbeBase):
return ModelVariantType.Depth
elif in_channels == 4:
return ModelVariantType.Normal
except:
except Exception:
pass
return ModelVariantType.Normal

View File

@ -56,7 +56,7 @@ class ModelSearch(ABC):
self.on_search_completed()
def walk_directory(self, path: Path):
for root, dirs, files in os.walk(path):
for root, dirs, files in os.walk(path, followlinks=True):
if str(Path(root).name).startswith("."):
self._pruned_paths.add(root)
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):

View File

@ -2,7 +2,7 @@ import inspect
from enum import Enum
from pydantic import BaseModel
from typing import Literal, get_origin
from .base import (
from .base import ( # noqa: F401
BaseModelType,
ModelType,
SubModelType,
@ -118,7 +118,7 @@ def get_model_config_enums():
fields = model_config.__annotations__
try:
field = fields["model_format"]
except:
except Exception:
raise Exception("format field not found")
# model_format: None

View File

@ -3,27 +3,28 @@ import os
import sys
import typing
import inspect
from enum import Enum
import warnings
from abc import ABCMeta, abstractmethod
from contextlib import suppress
from enum import Enum
from pathlib import Path
from picklescan.scanner import scan_file_path
import torch
import numpy as np
import safetensors.torch
from pathlib import Path
from diffusers import DiffusionPipeline, ConfigMixin, OnnxRuntimeModel
from contextlib import suppress
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
import onnx
import safetensors.torch
from diffusers import DiffusionPipeline, ConfigMixin
from onnx import numpy_helper
from onnxruntime import (
InferenceSession,
SessionOptions,
get_available_providers,
)
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging
class DuplicateModelException(Exception):
@ -171,7 +172,7 @@ class ModelBase(metaclass=ABCMeta):
fields = value.__annotations__
try:
field = fields["model_format"]
except:
except Exception:
raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})")
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
@ -244,7 +245,7 @@ class DiffusersModel(ModelBase):
try:
config_data = DiffusionPipeline.load_config(self.model_path)
# config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
except:
except Exception:
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
config_data.pop("_ignore_files", None)
@ -343,7 +344,7 @@ def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, vari
with open(os.path.join(model_path, file), "r") as f:
index_data = json.loads(f.read())
return int(index_data["metadata"]["total_size"])
except:
except Exception:
pass
# calculate files size if there is no index file
@ -440,7 +441,7 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
if str(path).endswith(".safetensors"):
try:
checkpoint = _fast_safetensors_reader(path)
except:
except Exception:
# TODO: create issue for support "meta"?
checkpoint = safetensors.torch.load_file(path, device="cpu")
else:
@ -452,11 +453,6 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
return checkpoint
import warnings
from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging
class SilenceWarnings(object):
def __init__(self):
self.transformers_verbosity = transformers_logging.get_verbosity()
@ -639,7 +635,7 @@ class IAIOnnxRuntimeModel:
raise Exception("You should call create_session before running model")
inputs = {k: np.array(v) for k, v in kwargs.items()}
output_names = self.session.get_outputs()
# output_names = self.session.get_outputs()
# for k in inputs:
# self.io_binding.bind_cpu_input(k, inputs[k])
# for name in output_names:

View File

@ -43,7 +43,7 @@ class ControlNetModel(ModelBase):
try:
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
# config = json.loads(os.path.join(self.model_path, "config.json"))
except:
except Exception:
raise Exception("Invalid controlnet model! (config.json not found or invalid)")
model_class_name = config.get("_class_name", None)
@ -53,7 +53,7 @@ class ControlNetModel(ModelBase):
try:
self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
self.model_size = calc_model_size_by_fs(self.model_path)
except:
except Exception:
raise Exception("Invalid ControlNet model!")
def get_size(self, child_type: Optional[SubModelType] = None):
@ -78,7 +78,7 @@ class ControlNetModel(ModelBase):
variant=variant,
)
break
except:
except Exception:
pass
if not model:
raise ModelNotFoundException()

View File

@ -330,5 +330,5 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
config_path = config_path.relative_to(app_config.root_path)
return str(config_path)
except:
except Exception:
return None

View File

@ -1,25 +1,17 @@
import os
import json
from enum import Enum
from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
from typing import Literal
from diffusers import OnnxRuntimeModel
from .base import (
ModelBase,
ModelConfigBase,
BaseModelType,
ModelType,
SubModelType,
ModelVariantType,
DiffusersModel,
SchedulerPredictionType,
SilenceWarnings,
read_checkpoint_meta,
classproperty,
OnnxRuntimeModel,
IAIOnnxRuntimeModel,
)
from invokeai.app.services.config import InvokeAIAppConfig
class StableDiffusionOnnxModelFormat(str, Enum):

View File

@ -44,14 +44,14 @@ class VaeModel(ModelBase):
try:
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
# config = json.loads(os.path.join(self.model_path, "config.json"))
except:
except Exception:
raise Exception("Invalid vae model! (config.json not found or invalid)")
try:
vae_class_name = config.get("_class_name", "AutoencoderKL")
self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name])
self.model_size = calc_model_size_by_fs(self.model_path)
except:
except Exception:
raise Exception("Invalid vae model! (Unkown vae type)")
def get_size(self, child_type: Optional[SubModelType] = None):

View File

@ -1,11 +1,15 @@
"""
Initialization file for the invokeai.backend.stable_diffusion package
"""
from .diffusers_pipeline import (
from .diffusers_pipeline import ( # noqa: F401
ConditioningData,
PipelineIntermediateState,
StableDiffusionGeneratorPipeline,
)
from .diffusion import InvokeAIDiffuserComponent
from .diffusion.cross_attention_map_saving import AttentionMapSaver
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings, BasicConditioningInfo, SDXLConditioningInfo
from .diffusion import InvokeAIDiffuserComponent # noqa: F401
from .diffusion.cross_attention_map_saving import AttentionMapSaver # noqa: F401
from .diffusion.shared_invokeai_diffusion import ( # noqa: F401
PostprocessingSettings,
BasicConditioningInfo,
SDXLConditioningInfo,
)

View File

@ -2,10 +2,8 @@ from __future__ import annotations
import dataclasses
import inspect
import math
import secrets
from dataclasses import dataclass, field
from typing import Any, Callable, Generic, List, Optional, Type, Union
from typing import Any, Callable, List, Optional, Union
import PIL.Image
import einops

View File

@ -1,9 +1,9 @@
"""
Initialization file for invokeai.models.diffusion
"""
from .cross_attention_control import InvokeAICrossAttentionMixin
from .cross_attention_map_saving import AttentionMapSaver
from .shared_invokeai_diffusion import (
from .cross_attention_control import InvokeAICrossAttentionMixin # noqa: F401
from .cross_attention_map_saving import AttentionMapSaver # noqa: F401
from .shared_invokeai_diffusion import ( # noqa: F401
InvokeAIDiffuserComponent,
PostprocessingSettings,
BasicConditioningInfo,

View File

@ -4,6 +4,7 @@
import enum
import math
from dataclasses import dataclass, field
from typing import Callable, Optional
import diffusers
@ -12,6 +13,11 @@ import torch
from compel.cross_attention_control import Arguments
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.attention_processor import (
Attention,
AttnProcessor,
SlicedAttnProcessor,
)
from torch import nn
import invokeai.backend.util.logging as logger
@ -522,14 +528,6 @@ class AttnProcessor:
return hidden_states
"""
from dataclasses import dataclass, field
import torch
from diffusers.models.attention_processor import (
Attention,
AttnProcessor,
SlicedAttnProcessor,
)
@dataclass

View File

@ -5,8 +5,6 @@ import torch
from torchvision.transforms.functional import InterpolationMode
from torchvision.transforms.functional import resize as tv_resize
from .cross_attention_control import CrossAttentionType, get_cross_attention_modules
class AttentionMapSaver:
def __init__(self, token_ids: range, latents_shape: torch.Size):

View File

@ -3,15 +3,12 @@ from __future__ import annotations
from contextlib import contextmanager
from dataclasses import dataclass
import math
from typing import Any, Callable, Dict, Optional, Union, List
from typing import Any, Callable, Optional, Union
import numpy as np
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttentionProcessor
from typing_extensions import TypeAlias
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from .cross_attention_control import (
@ -579,7 +576,7 @@ class InvokeAIDiffuserComponent:
latents.to(device="cpu")
if (
h_symmetry_time_pct != None
h_symmetry_time_pct is not None
and self.last_percent_through < h_symmetry_time_pct
and percent_through >= h_symmetry_time_pct
):
@ -595,7 +592,7 @@ class InvokeAIDiffuserComponent:
)
if (
v_symmetry_time_pct != None
v_symmetry_time_pct is not None
and self.last_percent_through < v_symmetry_time_pct
and percent_through >= v_symmetry_time_pct
):

View File

@ -1,6 +1,6 @@
from ldm.modules.image_degradation.bsrgan import (
from ldm.modules.image_degradation.bsrgan import ( # noqa: F401
degradation_bsrgan_variant as degradation_fn_bsr,
)
from ldm.modules.image_degradation.bsrgan_light import (
from ldm.modules.image_degradation.bsrgan_light import ( # noqa: F401
degradation_bsrgan_variant as degradation_fn_bsr_light,
)

View File

@ -573,14 +573,15 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
"""
image = util.uint2single(image)
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
sf_ori = sf
jpeg_prob, scale2_prob = 0.9, 0.25
# isp_prob = 0.25 # uncomment with `if i== 6` block below
# sf_ori = sf # uncomment with `if i== 6` block below
h1, w1 = image.shape[:2]
image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = image.shape[:2]
hq = image.copy()
# hq = image.copy() # uncomment with `if i== 6` block below
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
@ -777,7 +778,7 @@ if __name__ == "__main__":
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
print(img_lq.shape)
print("bicubic", img_lq_bicubic.shape)
print(img_hq.shape)
# print(img_hq.shape)
lq_nearest = cv2.resize(
util.single2uint(img_lq),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
@ -788,5 +789,6 @@ if __name__ == "__main__":
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
# img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest], axis=1)
util.imsave(img_concat, str(i) + ".png")

View File

@ -577,14 +577,15 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
"""
image = util.uint2single(image)
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
sf_ori = sf
jpeg_prob, scale2_prob = 0.9, 0.25
# isp_prob = 0.25 # uncomment with `if i== 6` block below
# sf_ori = sf # uncomment with `if i== 6` block below
h1, w1 = image.shape[:2]
image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = image.shape[:2]
hq = image.copy()
# hq = image.copy() # uncomment with `if i== 6` block below
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:

View File

@ -8,8 +8,6 @@ import numpy as np
import torch
from torchvision.utils import make_grid
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
import invokeai.backend.util.logging as logger
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
@ -50,6 +48,8 @@ def get_timestamp():
def imshow(x, title=None, cbar=False, figsize=None):
import matplotlib.pyplot as plt
plt.figure(figsize=figsize)
plt.imshow(np.squeeze(x), interpolation="nearest", cmap="gray")
if title:
@ -60,6 +60,8 @@ def imshow(x, title=None, cbar=False, figsize=None):
def surf(Z, cmap="rainbow", figsize=None):
import matplotlib.pyplot as plt
plt.figure(figsize=figsize)
ax3 = plt.axes(projection="3d")
@ -89,7 +91,7 @@ def get_image_paths(dataroot):
def _get_paths_from_images(path):
assert os.path.isdir(path), "{:s} is not a valid directory".format(path)
images = []
for dirpath, _, fnames in sorted(os.walk(path)):
for dirpath, _, fnames in sorted(os.walk(path, followlinks=True)):
for fname in sorted(fnames):
if is_image_file(fname):
img_path = os.path.join(dirpath, fname)

View File

@ -1 +1 @@
from .schedulers import SCHEDULER_MAP
from .schedulers import SCHEDULER_MAP # noqa: F401

View File

@ -1,4 +1,4 @@
"""
Initialization file for invokeai.backend.training
"""
from .textual_inversion_training import do_textual_inversion_training, parse_args
from .textual_inversion_training import do_textual_inversion_training, parse_args # noqa: F401

View File

@ -1,7 +1,7 @@
"""
Initialization file for invokeai.backend.util
"""
from .devices import (
from .devices import ( # noqa: F401
CPU_DEVICE,
CUDA_DEVICE,
MPS_DEVICE,
@ -10,5 +10,5 @@ from .devices import (
normalize_device,
torch_dtype,
)
from .log import write_log
from .util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name, Chdir
from .log import write_log # noqa: F401
from .util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name, Chdir # noqa: F401

View File

@ -25,10 +25,15 @@ from diffusers.models.unet_2d_condition import UNet2DConditionModel
import diffusers
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
from invokeai.backend.util.logging import InvokeAILogger
# TODO: create PR to diffusers
# Modified ControlNetModel with encoder_attention_mask argument added
logger = InvokeAILogger.getLogger(__name__)
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
"""
A ControlNet model.
@ -111,7 +116,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
"DownBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,

View File

@ -27,8 +27,8 @@ def write_log_message(results, output_cntr):
log_lines = [f"{path}: {prompt}\n" for path, prompt in results]
if len(log_lines) > 1:
subcntr = 1
for l in log_lines:
print(f"[{output_cntr}.{subcntr}] {l}", end="")
for ll in log_lines:
print(f"[{output_cntr}.{subcntr}] {ll}", end="")
subcntr += 1
else:
print(f"[{output_cntr}] {log_lines[0]}", end="")

View File

@ -182,13 +182,13 @@ import urllib.parse
from abc import abstractmethod
from pathlib import Path
from invokeai.app.services.config import InvokeAIAppConfig, get_invokeai_config
from invokeai.app.services.config import InvokeAIAppConfig
try:
import syslog
SYSLOG_AVAILABLE = True
except:
except ImportError:
SYSLOG_AVAILABLE = False
@ -417,7 +417,7 @@ class InvokeAILogger(object):
syslog_args["socktype"] = _SOCK_MAP[arg_value[0]]
else:
syslog_args["address"] = arg_name
except:
except Exception:
raise ValueError(f"{args} is not a value argument list for syslog logging")
return logging.handlers.SysLogHandler(**syslog_args)

View File

@ -191,7 +191,7 @@ class ChunkedSlicedAttnProcessor:
assert value.shape[0] == 1
assert hidden_states.shape[0] == 1
dtype = query.dtype
# dtype = query.dtype
if attn.upcast_attention:
query = query.float()
key = key.float()

View File

@ -84,7 +84,7 @@ def count_params(model, verbose=False):
def instantiate_from_config(config, **kwargs):
if not "target" in config:
if "target" not in config:
if config == "__is_first_stage__":
return None
elif config == "__is_unconditional__":
@ -234,7 +234,8 @@ def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10
.repeat_interleave(d[1], 1)
)
dot = lambda grad, shift: (
def dot(grad, shift):
return (
torch.stack(
(
grid[: shape[0], : shape[1], 0] + shift[0],
@ -287,7 +288,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
if dest.is_dir():
try:
file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)
except:
except AttributeError:
file_name = os.path.basename(url)
dest = dest / file_name
else:
@ -342,7 +343,7 @@ def url_attachment_name(url: str) -> dict:
resp = requests.get(url, stream=True)
match = re.search('filename="(.+)"', resp.headers.get("Content-Disposition"))
return match.group(1)
except:
except Exception:
return None

View File

@ -1,4 +1,4 @@
"""
Initialization file for invokeai.frontend.CLI
"""
from .CLI import main as invokeai_command_line_interface
from .CLI import main as invokeai_command_line_interface # noqa: F401

View File

@ -1,4 +1,4 @@
"""
Wrapper for invokeai.backend.configure.invokeai_configure
"""
from ...backend.install.invokeai_configure import main as invokeai_configure
from ...backend.install.invokeai_configure import main as invokeai_configure # noqa: F401

View File

@ -80,7 +80,7 @@ def welcome(versions: dict):
def get_extras():
extras = ""
try:
dist = pkg_resources.get_distribution("xformers")
_ = pkg_resources.get_distribution("xformers")
extras = "[xformers]"
except pkg_resources.DistributionNotFound:
pass
@ -90,7 +90,7 @@ def get_extras():
def main():
versions = get_versions()
if invokeai_is_running():
print(f":exclamation: [bold red]Please terminate all running instances of InvokeAI before updating.[/red bold]")
print(":exclamation: [bold red]Please terminate all running instances of InvokeAI before updating.[/red bold]")
input("Press any key to continue...")
return
@ -122,9 +122,9 @@ def main():
print("")
print("")
if os.system(cmd) == 0:
print(f":heavy_check_mark: Upgrade successful")
print(":heavy_check_mark: Upgrade successful")
else:
print(f":exclamation: [bold red]Upgrade failed[/red bold]")
print(":exclamation: [bold red]Upgrade failed[/red bold]")
if __name__ == "__main__":

View File

@ -251,7 +251,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
) -> dict[str, npyscreen.widget]:
"""Generic code to create model selection widgets"""
widgets = dict()
model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and not x in exclude]
model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude]
model_labels = [self.model_labels[x] for x in model_list]
show_recommended = len(self.installed_models) == 0
@ -357,14 +357,14 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
try:
v.hidden = True
v.editable = False
except:
except Exception:
pass
for k, v in widgets[selected_tab].items():
try:
v.hidden = False
if not isinstance(v, (npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)):
v.editable = True
except:
except Exception:
pass
self.__class__.current_tab = selected_tab # for persistence
self.display()
@ -541,7 +541,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
self.ti_models,
]
for section in ui_sections:
if not "models_selected" in section:
if "models_selected" not in section:
continue
selected = set([section["models"][x] for x in section["models_selected"].value])
models_to_install = [x for x in selected if not self.all_models[x].installed]
@ -637,7 +637,7 @@ def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection) -> SchedulerPre
return None
else:
return response
except:
except Exception:
return None
@ -673,8 +673,7 @@ def process_and_execute(
def select_and_download_models(opt: Namespace):
precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
config.precision = precision
helper = lambda x: ask_user_for_prediction_type(x)
installer = ModelInstall(config, prediction_type_helper=helper)
installer = ModelInstall(config, prediction_type_helper=ask_user_for_prediction_type)
if opt.list_models:
installer.list_models(opt.list_models)
elif opt.add or opt.delete:

View File

@ -102,8 +102,8 @@ def set_min_terminal_size(min_cols: int, min_lines: int) -> bool:
class IntSlider(npyscreen.Slider):
def translate_value(self):
stri = "%2d / %2d" % (self.value, self.out_of)
l = (len(str(self.out_of))) * 2 + 4
stri = stri.rjust(l)
length = (len(str(self.out_of))) * 2 + 4
stri = stri.rjust(length)
return stri
@ -167,8 +167,8 @@ class FloatSlider(npyscreen.Slider):
# this is supposed to adjust display precision, but doesn't
def translate_value(self):
stri = "%3.2f / %3.2f" % (self.value, self.out_of)
l = (len(str(self.out_of))) * 2 + 4
stri = stri.rjust(l)
length = (len(str(self.out_of))) * 2 + 4
stri = stri.rjust(length)
return stri

View File

@ -1,4 +1,3 @@
import os
import sys
import argparse

View File

@ -1,4 +1,4 @@
"""
Initialization file for invokeai.frontend.merge
"""
from .merge_diffusers import main as invokeai_merge_diffusers
from .merge_diffusers import main as invokeai_merge_diffusers # noqa: F401

View File

@ -9,19 +9,15 @@ import curses
import sys
from argparse import Namespace
from pathlib import Path
from typing import List, Union
from typing import List, Optional
import npyscreen
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
from npyscreen import widget
from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import (
ModelMerger,
MergeInterpolationMethod,
ModelManager,
ModelType,
BaseModelType,
@ -318,7 +314,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
else:
return True
def get_model_names(self, base_model: BaseModelType = None) -> List[str]:
def get_model_names(self, base_model: Optional[BaseModelType] = None) -> List[str]:
model_names = [
info["model_name"]
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)

View File

@ -1,4 +1,4 @@
"""
Initialization file for invokeai.frontend.training
"""
from .textual_inversion import main as invokeai_textual_inversion
from .textual_inversion import main as invokeai_textual_inversion # noqa: F401

View File

@ -59,7 +59,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
try:
default = self.model_names.index(saved_args["model"])
except:
except Exception:
pass
self.add_widget_intelligent(
@ -377,7 +377,7 @@ def previous_args() -> dict:
try:
conf = OmegaConf.load(conf_file)
conf["placeholder_token"] = conf["placeholder_token"].strip("<>")
except:
except Exception:
conf = None
return conf

View File

@ -4,7 +4,7 @@ import { InvokeLogLevel } from 'app/logging/logger';
import { userInvoked } from 'app/store/actions';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { t } from 'i18next';
import { startCase } from 'lodash-es';
import { startCase, upperFirst } from 'lodash-es';
import { LogLevelName } from 'roarr';
import {
isAnySessionRejected,
@ -26,6 +26,7 @@ import {
import { ProgressImage } from 'services/events/types';
import { makeToast } from '../util/makeToast';
import { LANGUAGES } from './constants';
import { zPydanticValidationError } from './zodSchemas';
export type CancelStrategy = 'immediate' | 'scheduled';
@ -361,9 +362,24 @@ export const systemSlice = createSlice({
state.progressImage = null;
let errorDescription = undefined;
const duration = 5000;
if (action.payload?.status === 422) {
errorDescription = 'Validation Error';
const result = zPydanticValidationError.safeParse(action.payload);
if (result.success) {
result.data.error.detail.map((e) => {
state.toastQueue.push(
makeToast({
title: upperFirst(e.msg),
status: 'error',
description: `Path:
${e.loc.slice(3).join('.')}`,
duration,
})
);
});
return;
}
} else if (action.payload?.error) {
errorDescription = action.payload?.error as string;
}
@ -373,6 +389,7 @@ export const systemSlice = createSlice({
title: t('toast.serverError'),
status: 'error',
description: errorDescription,
duration,
})
);
});

View File

@ -0,0 +1,14 @@
import { z } from 'zod';
export const zPydanticValidationError = z.object({
status: z.literal(422),
error: z.object({
detail: z.array(
z.object({
loc: z.array(z.string()),
msg: z.string(),
type: z.string(),
})
),
}),
});

View File

@ -1,7 +1,7 @@
"""
initialization file for invokeai
"""
from .invokeai_version import __version__
from .invokeai_version import __version__ # noqa: F401
__app_id__ = "invoke-ai/InvokeAI"
__app_name__ = "InvokeAI"

View File

@ -8,9 +8,8 @@ from google.colab import files
from IPython.display import Image as ipyimg
import ipywidgets as widgets
from PIL import Image
from numpy import asarray
from einops import rearrange, repeat
import torch, torchvision
import torchvision
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import ismap
import time
@ -68,14 +67,14 @@ def get_custom_cond(mode):
elif mode == "text_conditional":
w = widgets.Text(value="A cake with cream!", disabled=True)
display(w)
display(w) # noqa: F821
with open(f"{dest}/{mode}/custom_{w.value[:20]}.txt", "w") as f:
f.write(w.value)
elif mode == "class_conditional":
w = widgets.IntSlider(min=0, max=1000)
display(w)
display(w) # noqa: F821
with open(f"{dest}/{mode}/custom.txt", "w") as f:
f.write(w.value)
@ -96,7 +95,7 @@ def select_cond_path(mode):
onlyfiles = [f for f in sorted(os.listdir(path))]
selected = widgets.RadioButtons(options=onlyfiles, description="Select conditioning:", disabled=False)
display(selected)
display(selected) # noqa: F821
selected_path = os.path.join(path, selected.value)
return selected_path
@ -123,7 +122,7 @@ def get_cond(mode, selected_path):
def visualize_cond_img(path):
display(ipyimg(filename=path))
display(ipyimg(filename=path)) # noqa: F821
def run(model, selected_path, task, custom_steps, resize_enabled=False, classifier_ckpt=None, global_step=None):
@ -331,7 +330,7 @@ def make_convolutional_sample(
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
log["sample_noquant"] = x_sample_noquant
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
except:
except Exception:
pass
log["sample"] = x_sample

View File

@ -95,7 +95,14 @@ dependencies = [
"dev" = [
"pudb",
]
"test" = ["pytest>6.0.0", "pytest-cov", "pytest-datadir", "black"]
"test" = [
"black",
"flake8",
"Flake8-pyproject",
"pytest>6.0.0",
"pytest-cov",
"pytest-datadir",
]
"xformers" = [
"xformers~=0.0.19; sys_platform!='darwin'",
"triton; sys_platform=='linux'",
@ -185,6 +192,8 @@ output = "coverage/index.xml"
[tool.flake8]
max-line-length = 120
ignore = ["E203", "E266", "E501", "W503"]
select = ["B", "C", "E", "F", "W", "T4"]
[tool.black]
line-length = 120

View File

@ -4,7 +4,6 @@ Read a checkpoint/safetensors file and write out a template .json file containin
its metadata for use in fast model probing.
"""
import sys
import argparse
import json

View File

@ -3,11 +3,12 @@
import warnings
from invokeai.app.cli_app import invoke_cli
warnings.warn(
"dream.py is being deprecated, please run invoke.py for the " "new UI/API or legacy_api.py for the old API",
DeprecationWarning,
)
from invokeai.app.cli_app import invoke_cli
invoke_cli()

View File

@ -2,7 +2,7 @@
"""This script reads the "Invoke" Stable Diffusion prompt embedded in files generated by invoke.py"""
import sys
from PIL import Image, PngImagePlugin
from PIL import Image
if len(sys.argv) < 2:
print("Usage: file2prompt.py <file1.png> <file2.png> <file3.png>...")

View File

@ -2,13 +2,11 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import os
import logging
logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage())
import os
import sys
def main():
# Change working directory to the repo root

View File

@ -2,13 +2,11 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import os
import logging
logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage())
import os
import sys
def main():
# Change working directory to the repo root

View File

@ -1,6 +1,7 @@
"""make variations of input image"""
import argparse, os, sys, glob
import argparse
import os
import PIL
import torch
import numpy as np
@ -12,7 +13,6 @@ from einops import rearrange, repeat
from torchvision.utils import make_grid
from torch import autocast
from contextlib import nullcontext
import time
from pytorch_lightning import seed_everything
from ldm.util import instantiate_from_config
@ -234,7 +234,6 @@ def main():
with torch.no_grad():
with precision_scope(device.type):
with model.ema_scope():
tic = time.time()
all_samples = list()
for n in trange(opt.n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"):
@ -279,8 +278,6 @@ def main():
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
grid_count += 1
toc = time.time()
print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.")

View File

@ -1,4 +1,6 @@
import argparse, os, sys, glob
import argparse
import glob
import os
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm

View File

@ -1,13 +1,13 @@
import argparse, os, sys, glob
import clip
import argparse
import glob
import os
import torch
import torch.nn as nn
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange, repeat
from einops import rearrange
from torchvision.utils import make_grid
import scann
import time
@ -390,8 +390,8 @@ if __name__ == "__main__":
grid = make_grid(grid, nrow=n_rows)
# to image
grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
grid_np = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
Image.fromarray(grid_np.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
grid_count += 1
print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")

View File

@ -1,24 +1,24 @@
import argparse, os, sys, datetime, glob, importlib, csv
import argparse
import datetime
import glob
import os
import sys
import numpy as np
import time
import torch
import torchvision
import pytorch_lightning as pl
from packaging import version
from omegaconf import OmegaConf
from torch.utils.data import random_split, DataLoader, Dataset, Subset
from torch.utils.data import DataLoader, Dataset
from functools import partial
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import (
ModelCheckpoint,
Callback,
LearningRateMonitor,
)
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities import rank_zero_info
@ -651,7 +651,7 @@ if __name__ == "__main__":
trainer_config["accelerator"] = "auto"
for k in nondefault_trainer_args(opt):
trainer_config[k] = getattr(opt, k)
if not "gpus" in trainer_config:
if "gpus" not in trainer_config:
del trainer_config["accelerator"]
cpu = True
else:
@ -803,7 +803,7 @@ if __name__ == "__main__":
trainer_opt.detect_anomaly = False
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
trainer.logdir = logdir ###
trainer.logdir = logdir
# data
config.data.params.train.params.data_root = opt.data_root

View File

@ -2,7 +2,7 @@ from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder
from ldm.modules.embedding_manager import EmbeddingManager
from ldm.invoke.globals import Globals
import argparse, os
import argparse
from functools import partial
import torch
@ -108,7 +108,7 @@ if __name__ == "__main__":
manager.load(manager_ckpt)
for placeholder_string in manager.string_to_token_dict:
if not placeholder_string in string_to_token_dict:
if placeholder_string not in string_to_token_dict:
string_to_token_dict[placeholder_string] = manager.string_to_token_dict[placeholder_string]
string_to_param_dict[placeholder_string] = manager.string_to_param_dict[placeholder_string]

View File

@ -1,6 +1,12 @@
import argparse, os, sys, glob, datetime, yaml
import torch
import argparse
import datetime
import glob
import os
import sys
import time
import yaml
import torch
import numpy as np
from tqdm import trange
@ -10,7 +16,9 @@ from PIL import Image
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config
rescale = lambda x: (x + 1.0) / 2.0
def rescale(x: float) -> float:
return (x + 1.0) / 2.0
def custom_to_pil(x):
@ -45,7 +53,7 @@ def logs2pil(logs, keys=["sample"]):
else:
print(f"Unknown format for key {k}. ")
img = None
except:
except Exception:
img = None
imgs[k] = img
return imgs

View File

@ -1,4 +1,5 @@
import os, sys
import os
import sys
import numpy as np
import scann
import argparse

View File

@ -1,4 +1,5 @@
import argparse, os, sys, glob
import argparse
import os
import torch
import numpy as np
from omegaconf import OmegaConf
@ -7,10 +8,9 @@ from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
from contextlib import nullcontext
import k_diffusion as K
import torch.nn as nn
@ -251,7 +251,6 @@ def main():
with torch.no_grad():
with precision_scope(device.type):
with model.ema_scope():
tic = time.time()
all_samples = list()
for n in trange(opt.n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"):
@ -310,8 +309,6 @@ def main():
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
grid_count += 1
toc = time.time()
print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.")

Some files were not shown because too many files have changed in this diff Show More