mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into fix/inpaint_gen
This commit is contained in:
commit
3c43594c26
8
.github/workflows/style-checks.yml
vendored
8
.github/workflows/style-checks.yml
vendored
@ -1,6 +1,6 @@
|
|||||||
name: style checks
|
name: style checks
|
||||||
# just formatting for now
|
# just formatting and flake8 for now
|
||||||
# TODO: add isort and flake8 later
|
# TODO: add isort later
|
||||||
|
|
||||||
on:
|
on:
|
||||||
pull_request:
|
pull_request:
|
||||||
@ -20,8 +20,8 @@ jobs:
|
|||||||
|
|
||||||
- name: Install dependencies with pip
|
- name: Install dependencies with pip
|
||||||
run: |
|
run: |
|
||||||
pip install black
|
pip install black flake8 Flake8-pyproject
|
||||||
|
|
||||||
# - run: isort --check-only .
|
# - run: isort --check-only .
|
||||||
- run: black --check .
|
- run: black --check .
|
||||||
# - run: flake8
|
- run: flake8
|
||||||
|
@ -8,3 +8,10 @@ repos:
|
|||||||
language: system
|
language: system
|
||||||
entry: black
|
entry: black
|
||||||
types: [python]
|
types: [python]
|
||||||
|
|
||||||
|
- id: flake8
|
||||||
|
name: flake8
|
||||||
|
stages: [commit]
|
||||||
|
language: system
|
||||||
|
entry: flake8
|
||||||
|
types: [python]
|
||||||
|
@ -407,7 +407,7 @@ def get_pip_from_venv(venv_path: Path) -> str:
|
|||||||
:rtype: str
|
:rtype: str
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pip = "Scripts\pip.exe" if OS == "Windows" else "bin/pip"
|
pip = "Scripts\\pip.exe" if OS == "Windows" else "bin/pip"
|
||||||
return str(venv_path.expanduser().resolve() / pip)
|
return str(venv_path.expanduser().resolve() / pip)
|
||||||
|
|
||||||
|
|
||||||
|
@ -49,7 +49,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
inst.install(**args.__dict__)
|
inst.install(**args.__dict__)
|
||||||
except KeyboardInterrupt as exc:
|
except KeyboardInterrupt:
|
||||||
print("\n")
|
print("\n")
|
||||||
print("Ctrl-C pressed. Aborting.")
|
print("Ctrl-C pressed. Aborting.")
|
||||||
print("Come back soon!")
|
print("Come back soon!")
|
||||||
|
@ -70,7 +70,7 @@ def confirm_install(dest: Path) -> bool:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print(f"InvokeAI will be installed in {dest}")
|
print(f"InvokeAI will be installed in {dest}")
|
||||||
dest_confirmed = not Confirm.ask(f"Would you like to pick a different location?", default=False)
|
dest_confirmed = not Confirm.ask("Would you like to pick a different location?", default=False)
|
||||||
console.line()
|
console.line()
|
||||||
|
|
||||||
return dest_confirmed
|
return dest_confirmed
|
||||||
@ -90,7 +90,7 @@ def dest_path(dest=None) -> Path:
|
|||||||
dest = Path(dest).expanduser().resolve()
|
dest = Path(dest).expanduser().resolve()
|
||||||
else:
|
else:
|
||||||
dest = Path.cwd().expanduser().resolve()
|
dest = Path.cwd().expanduser().resolve()
|
||||||
prev_dest = dest.expanduser().resolve()
|
prev_dest = init_path = dest
|
||||||
|
|
||||||
dest_confirmed = confirm_install(dest)
|
dest_confirmed = confirm_install(dest)
|
||||||
|
|
||||||
@ -109,9 +109,9 @@ def dest_path(dest=None) -> Path:
|
|||||||
)
|
)
|
||||||
|
|
||||||
console.line()
|
console.line()
|
||||||
print(f"[orange3]Please select the destination directory for the installation:[/] \[{browse_start}]: ")
|
console.print(f"[orange3]Please select the destination directory for the installation:[/] \\[{browse_start}]: ")
|
||||||
selected = prompt(
|
selected = prompt(
|
||||||
f">>> ",
|
">>> ",
|
||||||
complete_in_thread=True,
|
complete_in_thread=True,
|
||||||
completer=path_completer,
|
completer=path_completer,
|
||||||
default=str(browse_start) + os.sep,
|
default=str(browse_start) + os.sep,
|
||||||
@ -134,14 +134,14 @@ def dest_path(dest=None) -> Path:
|
|||||||
try:
|
try:
|
||||||
dest.mkdir(exist_ok=True, parents=True)
|
dest.mkdir(exist_ok=True, parents=True)
|
||||||
return dest
|
return dest
|
||||||
except PermissionError as exc:
|
except PermissionError:
|
||||||
print(
|
console.print(
|
||||||
f"Failed to create directory {dest} due to insufficient permissions",
|
f"Failed to create directory {dest} due to insufficient permissions",
|
||||||
style=Style(color="red"),
|
style=Style(color="red"),
|
||||||
highlight=True,
|
highlight=True,
|
||||||
)
|
)
|
||||||
except OSError as exc:
|
except OSError:
|
||||||
console.print_exception(exc)
|
console.print_exception()
|
||||||
|
|
||||||
if Confirm.ask("Would you like to try again?"):
|
if Confirm.ask("Would you like to try again?"):
|
||||||
dest_path(init_path)
|
dest_path(init_path)
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from invokeai.app.services.board_image_record_storage import (
|
from invokeai.app.services.board_image_record_storage import (
|
||||||
SqliteBoardImageRecordStorage,
|
SqliteBoardImageRecordStorage,
|
||||||
@ -45,7 +44,7 @@ def check_internet() -> bool:
|
|||||||
try:
|
try:
|
||||||
urllib.request.urlopen(host, timeout=1)
|
urllib.request.urlopen(host, timeout=1)
|
||||||
return True
|
return True
|
||||||
except:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ async def add_image_to_board(
|
|||||||
board_id=board_id, image_name=image_name
|
board_id=board_id, image_name=image_name
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise HTTPException(status_code=500, detail="Failed to add image to board")
|
raise HTTPException(status_code=500, detail="Failed to add image to board")
|
||||||
|
|
||||||
|
|
||||||
@ -53,7 +53,7 @@ async def remove_image_from_board(
|
|||||||
try:
|
try:
|
||||||
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise HTTPException(status_code=500, detail="Failed to remove image from board")
|
raise HTTPException(status_code=500, detail="Failed to remove image from board")
|
||||||
|
|
||||||
|
|
||||||
@ -79,10 +79,10 @@ async def add_images_to_board(
|
|||||||
board_id=board_id, image_name=image_name
|
board_id=board_id, image_name=image_name
|
||||||
)
|
)
|
||||||
added_image_names.append(image_name)
|
added_image_names.append(image_name)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return AddImagesToBoardResult(board_id=board_id, added_image_names=added_image_names)
|
return AddImagesToBoardResult(board_id=board_id, added_image_names=added_image_names)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise HTTPException(status_code=500, detail="Failed to add images to board")
|
raise HTTPException(status_code=500, detail="Failed to add images to board")
|
||||||
|
|
||||||
|
|
||||||
@ -105,8 +105,8 @@ async def remove_images_from_board(
|
|||||||
try:
|
try:
|
||||||
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
||||||
removed_image_names.append(image_name)
|
removed_image_names.append(image_name)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return RemoveImagesFromBoardResult(removed_image_names=removed_image_names)
|
return RemoveImagesFromBoardResult(removed_image_names=removed_image_names)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise HTTPException(status_code=500, detail="Failed to remove images from board")
|
raise HTTPException(status_code=500, detail="Failed to remove images from board")
|
||||||
|
@ -37,7 +37,7 @@ async def create_board(
|
|||||||
try:
|
try:
|
||||||
result = ApiDependencies.invoker.services.boards.create(board_name=board_name)
|
result = ApiDependencies.invoker.services.boards.create(board_name=board_name)
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise HTTPException(status_code=500, detail="Failed to create board")
|
raise HTTPException(status_code=500, detail="Failed to create board")
|
||||||
|
|
||||||
|
|
||||||
@ -50,7 +50,7 @@ async def get_board(
|
|||||||
try:
|
try:
|
||||||
result = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
result = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise HTTPException(status_code=404, detail="Board not found")
|
raise HTTPException(status_code=404, detail="Board not found")
|
||||||
|
|
||||||
|
|
||||||
@ -73,7 +73,7 @@ async def update_board(
|
|||||||
try:
|
try:
|
||||||
result = ApiDependencies.invoker.services.boards.update(board_id=board_id, changes=changes)
|
result = ApiDependencies.invoker.services.boards.update(board_id=board_id, changes=changes)
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise HTTPException(status_code=500, detail="Failed to update board")
|
raise HTTPException(status_code=500, detail="Failed to update board")
|
||||||
|
|
||||||
|
|
||||||
@ -105,7 +105,7 @@ async def delete_board(
|
|||||||
deleted_board_images=deleted_board_images,
|
deleted_board_images=deleted_board_images,
|
||||||
deleted_images=[],
|
deleted_images=[],
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise HTTPException(status_code=500, detail="Failed to delete board")
|
raise HTTPException(status_code=500, detail="Failed to delete board")
|
||||||
|
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ async def upload_image(
|
|||||||
if crop_visible:
|
if crop_visible:
|
||||||
bbox = pil_image.getbbox()
|
bbox = pil_image.getbbox()
|
||||||
pil_image = pil_image.crop(bbox)
|
pil_image = pil_image.crop(bbox)
|
||||||
except:
|
except Exception:
|
||||||
# Error opening the image
|
# Error opening the image
|
||||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||||
|
|
||||||
@ -73,7 +73,7 @@ async def upload_image(
|
|||||||
response.headers["Location"] = image_dto.image_url
|
response.headers["Location"] = image_dto.image_url
|
||||||
|
|
||||||
return image_dto
|
return image_dto
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise HTTPException(status_code=500, detail="Failed to create image")
|
raise HTTPException(status_code=500, detail="Failed to create image")
|
||||||
|
|
||||||
|
|
||||||
@ -85,7 +85,7 @@ async def delete_image(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
ApiDependencies.invoker.services.images.delete(image_name)
|
ApiDependencies.invoker.services.images.delete(image_name)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
# TODO: Does this need any exception handling at all?
|
# TODO: Does this need any exception handling at all?
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -97,7 +97,7 @@ async def clear_intermediates() -> int:
|
|||||||
try:
|
try:
|
||||||
count_deleted = ApiDependencies.invoker.services.images.delete_intermediates()
|
count_deleted = ApiDependencies.invoker.services.images.delete_intermediates()
|
||||||
return count_deleted
|
return count_deleted
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise HTTPException(status_code=500, detail="Failed to clear intermediates")
|
raise HTTPException(status_code=500, detail="Failed to clear intermediates")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -115,7 +115,7 @@ async def update_image(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
return ApiDependencies.invoker.services.images.update(image_name, image_changes)
|
return ApiDependencies.invoker.services.images.update(image_name, image_changes)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise HTTPException(status_code=400, detail="Failed to update image")
|
raise HTTPException(status_code=400, detail="Failed to update image")
|
||||||
|
|
||||||
|
|
||||||
@ -131,7 +131,7 @@ async def get_image_dto(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
return ApiDependencies.invoker.services.images.get_dto(image_name)
|
return ApiDependencies.invoker.services.images.get_dto(image_name)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise HTTPException(status_code=404)
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
|
||||||
@ -147,7 +147,7 @@ async def get_image_metadata(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
return ApiDependencies.invoker.services.images.get_metadata(image_name)
|
return ApiDependencies.invoker.services.images.get_metadata(image_name)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise HTTPException(status_code=404)
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
|
||||||
@ -183,7 +183,7 @@ async def get_image_full(
|
|||||||
)
|
)
|
||||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise HTTPException(status_code=404)
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
|
||||||
@ -212,7 +212,7 @@ async def get_image_thumbnail(
|
|||||||
response = FileResponse(path, media_type="image/webp", content_disposition_type="inline")
|
response = FileResponse(path, media_type="image/webp", content_disposition_type="inline")
|
||||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise HTTPException(status_code=404)
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
|
||||||
@ -234,7 +234,7 @@ async def get_image_urls(
|
|||||||
image_url=image_url,
|
image_url=image_url,
|
||||||
thumbnail_url=thumbnail_url,
|
thumbnail_url=thumbnail_url,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise HTTPException(status_code=404)
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
|
||||||
@ -282,10 +282,10 @@ async def delete_images_from_list(
|
|||||||
try:
|
try:
|
||||||
ApiDependencies.invoker.services.images.delete(image_name)
|
ApiDependencies.invoker.services.images.delete(image_name)
|
||||||
deleted_images.append(image_name)
|
deleted_images.append(image_name)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return DeleteImagesFromListResult(deleted_images=deleted_images)
|
return DeleteImagesFromListResult(deleted_images=deleted_images)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise HTTPException(status_code=500, detail="Failed to delete images")
|
raise HTTPException(status_code=500, detail="Failed to delete images")
|
||||||
|
|
||||||
|
|
||||||
@ -303,10 +303,10 @@ async def star_images_in_list(
|
|||||||
try:
|
try:
|
||||||
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=True))
|
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=True))
|
||||||
updated_image_names.append(image_name)
|
updated_image_names.append(image_name)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
|
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise HTTPException(status_code=500, detail="Failed to star images")
|
raise HTTPException(status_code=500, detail="Failed to star images")
|
||||||
|
|
||||||
|
|
||||||
@ -320,8 +320,8 @@ async def unstar_images_in_list(
|
|||||||
try:
|
try:
|
||||||
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=False))
|
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=False))
|
||||||
updated_image_names.append(image_name)
|
updated_image_names.append(image_name)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
|
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise HTTPException(status_code=500, detail="Failed to unstar images")
|
raise HTTPException(status_code=500, detail="Failed to unstar images")
|
||||||
|
@ -1,12 +1,13 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from typing import Annotated, List, Optional, Union
|
from typing import Annotated, Optional, Union
|
||||||
|
|
||||||
from fastapi import Body, HTTPException, Path, Query, Response
|
from fastapi import Body, HTTPException, Path, Query, Response
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
|
|
||||||
from ...invocations import *
|
# Importing * is bad karma but needed here for node detection
|
||||||
|
from ...invocations import * # noqa: F401 F403
|
||||||
from ...invocations.baseinvocation import BaseInvocation
|
from ...invocations.baseinvocation import BaseInvocation
|
||||||
from ...services.graph import (
|
from ...services.graph import (
|
||||||
Edge,
|
Edge,
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@ -17,21 +16,11 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pydantic.schema import schema
|
from pydantic.schema import schema
|
||||||
|
|
||||||
# This should come early so that modules can log their initialization properly
|
|
||||||
from .services.config import InvokeAIAppConfig
|
from .services.config import InvokeAIAppConfig
|
||||||
from ..backend.util.logging import InvokeAILogger
|
from ..backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
|
||||||
app_config.parse_args()
|
|
||||||
logger = InvokeAILogger.getLogger(config=app_config)
|
|
||||||
from invokeai.version.invokeai_version import __version__
|
from invokeai.version.invokeai_version import __version__
|
||||||
|
|
||||||
# we call this early so that the message appears before
|
|
||||||
# other invokeai initialization messages
|
|
||||||
if app_config.version:
|
|
||||||
print(f"InvokeAI version {__version__}")
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
import invokeai.frontend.web as web_dir
|
import invokeai.frontend.web as web_dir
|
||||||
import mimetypes
|
import mimetypes
|
||||||
|
|
||||||
@ -40,12 +29,17 @@ from .api.routers import sessions, models, images, boards, board_images, app_inf
|
|||||||
from .api.sockets import SocketIO
|
from .api.sockets import SocketIO
|
||||||
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
|
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import invokeai.backend.util.hotfixes
|
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||||
|
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
import invokeai.backend.util.mps_fixes
|
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||||
|
|
||||||
|
|
||||||
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
|
app_config.parse_args()
|
||||||
|
logger = InvokeAILogger.getLogger(config=app_config)
|
||||||
|
|
||||||
|
|
||||||
# fix for windows mimetypes registry entries being borked
|
# fix for windows mimetypes registry entries being borked
|
||||||
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
|
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
|
||||||
@ -230,13 +224,16 @@ def invoke_api():
|
|||||||
|
|
||||||
# replace uvicorn's loggers with InvokeAI's for consistent appearance
|
# replace uvicorn's loggers with InvokeAI's for consistent appearance
|
||||||
for logname in ["uvicorn.access", "uvicorn"]:
|
for logname in ["uvicorn.access", "uvicorn"]:
|
||||||
l = logging.getLogger(logname)
|
log = logging.getLogger(logname)
|
||||||
l.handlers.clear()
|
log.handlers.clear()
|
||||||
for ch in logger.handlers:
|
for ch in logger.handlers:
|
||||||
l.addHandler(ch)
|
log.addHandler(ch)
|
||||||
|
|
||||||
loop.run_until_complete(server.serve())
|
loop.run_until_complete(server.serve())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
if app_config.version:
|
||||||
|
print(f"InvokeAI version {__version__}")
|
||||||
|
else:
|
||||||
invoke_api()
|
invoke_api()
|
||||||
|
@ -145,10 +145,10 @@ def set_autocompleter(services: InvocationServices) -> Completer:
|
|||||||
completer = Completer(services.model_manager)
|
completer = Completer(services.model_manager)
|
||||||
|
|
||||||
readline.set_completer(completer.complete)
|
readline.set_completer(completer.complete)
|
||||||
# pyreadline3 does not have a set_auto_history() method
|
|
||||||
try:
|
try:
|
||||||
readline.set_auto_history(True)
|
readline.set_auto_history(True)
|
||||||
except:
|
except AttributeError:
|
||||||
|
# pyreadline3 does not have a set_auto_history() method
|
||||||
pass
|
pass
|
||||||
readline.set_pre_input_hook(completer._pre_input_hook)
|
readline.set_pre_input_hook(completer._pre_input_hook)
|
||||||
readline.set_completer_delims(" ")
|
readline.set_completer_delims(" ")
|
||||||
|
@ -13,16 +13,8 @@ from pydantic.fields import Field
|
|||||||
# This should come early so that the logger can pick up its configuration options
|
# This should come early so that the logger can pick up its configuration options
|
||||||
from .services.config import InvokeAIAppConfig
|
from .services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
|
||||||
config.parse_args()
|
|
||||||
logger = InvokeAILogger().getLogger(config=config)
|
|
||||||
from invokeai.version.invokeai_version import __version__
|
from invokeai.version.invokeai_version import __version__
|
||||||
|
|
||||||
# we call this early so that the message appears before other invokeai initialization messages
|
|
||||||
if config.version:
|
|
||||||
print(f"InvokeAI version {__version__}")
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
from invokeai.app.services.board_image_record_storage import (
|
from invokeai.app.services.board_image_record_storage import (
|
||||||
SqliteBoardImageRecordStorage,
|
SqliteBoardImageRecordStorage,
|
||||||
@ -62,10 +54,15 @@ from .services.processor import DefaultInvocationProcessor
|
|||||||
from .services.sqlite import SqliteItemStorage
|
from .services.sqlite import SqliteItemStorage
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import invokeai.backend.util.hotfixes
|
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||||
|
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
import invokeai.backend.util.mps_fixes
|
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||||
|
|
||||||
|
|
||||||
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
config.parse_args()
|
||||||
|
logger = InvokeAILogger().getLogger(config=config)
|
||||||
|
|
||||||
|
|
||||||
class CliCommand(BaseModel):
|
class CliCommand(BaseModel):
|
||||||
@ -482,4 +479,7 @@ def invoke_cli():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
if config.version:
|
||||||
|
print(f"InvokeAI version {__version__}")
|
||||||
|
else:
|
||||||
invoke_cli()
|
invoke_cli()
|
||||||
|
@ -5,10 +5,10 @@ from typing import Literal
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import validator
|
from pydantic import validator
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import ImageCollectionOutput, ImageField, IntegerCollectionOutput
|
from invokeai.app.invocations.primitives import IntegerCollectionOutput
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIType, tags, title
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
|
||||||
|
|
||||||
|
|
||||||
@title("Integer Range")
|
@title("Integer Range")
|
||||||
|
@ -12,7 +12,7 @@ from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion impor
|
|||||||
SDXLConditioningInfo,
|
SDXLConditioningInfo,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ...backend.model_management import ModelPatcher, ModelType
|
from ...backend.model_management.models import ModelType
|
||||||
from ...backend.model_management.lora import ModelPatcher
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
from ...backend.model_management.models import ModelNotFoundException
|
from ...backend.model_management.models import ModelNotFoundException
|
||||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||||
|
@ -29,7 +29,7 @@ from pydantic import BaseModel, Field, validator
|
|||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||||
|
|
||||||
|
|
||||||
from ...backend.model_management import BaseModelType, ModelType
|
from ...backend.model_management import BaseModelType
|
||||||
from ..models.image import ImageCategory, ResourceOrigin
|
from ..models.image import ImageCategory, ResourceOrigin
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
|
@ -90,7 +90,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
|
|||||||
return im
|
return im
|
||||||
|
|
||||||
# Find all invalid tiles and replace with a random valid tile
|
# Find all invalid tiles and replace with a random valid tile
|
||||||
replace_count = (tiles_mask == False).sum()
|
replace_count = (tiles_mask is False).sum()
|
||||||
rng = np.random.default_rng(seed=seed)
|
rng = np.random.default_rng(seed=seed)
|
||||||
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]
|
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ from diffusers.models.attention_processor import (
|
|||||||
)
|
)
|
||||||
from diffusers.schedulers import DPMSolverSDEScheduler
|
from diffusers.schedulers import DPMSolverSDEScheduler
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
from pydantic import BaseModel, Field, validator
|
from pydantic import validator
|
||||||
from torchvision.transforms.functional import resize as tv_resize
|
from torchvision.transforms.functional import resize as tv_resize
|
||||||
|
|
||||||
from invokeai.app.invocations.metadata import CoreMetadata
|
from invokeai.app.invocations.metadata import CoreMetadata
|
||||||
@ -32,7 +32,7 @@ from invokeai.app.util.controlnet_utils import prepare_control_image
|
|||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||||
|
|
||||||
from ...backend.model_management import BaseModelType, ModelPatcher
|
from ...backend.model_management.models import BaseModelType
|
||||||
from ...backend.model_management.lora import ModelPatcher
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||||
@ -47,12 +47,10 @@ from ...backend.util.devices import choose_precision, choose_torch_device
|
|||||||
from ..models.image import ImageCategory, ResourceOrigin
|
from ..models.image import ImageCategory, ResourceOrigin
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
|
||||||
FieldDescriptions,
|
FieldDescriptions,
|
||||||
Input,
|
Input,
|
||||||
InputField,
|
InputField,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
OutputField,
|
|
||||||
UIType,
|
UIType,
|
||||||
tags,
|
tags,
|
||||||
title,
|
title,
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import copy
|
import copy
|
||||||
from typing import List, Literal, Optional, Union
|
from typing import List, Literal, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
@ -16,7 +16,6 @@ from .baseinvocation import (
|
|||||||
InputField,
|
InputField,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
OutputField,
|
OutputField,
|
||||||
UIType,
|
|
||||||
tags,
|
tags,
|
||||||
title,
|
title,
|
||||||
)
|
)
|
||||||
|
@ -2,14 +2,13 @@
|
|||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import re
|
import re
|
||||||
from contextlib import ExitStack
|
|
||||||
|
# from contextlib import ExitStack
|
||||||
from typing import List, Literal, Optional, Union
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
|
|
||||||
from diffusers.image_processor import VaeImageProcessor
|
from diffusers.image_processor import VaeImageProcessor
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
|
||||||
from pydantic import BaseModel, Field, validator
|
from pydantic import BaseModel, Field, validator
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
@ -72,7 +71,7 @@ class ONNXPromptInvocation(BaseInvocation):
|
|||||||
text_encoder_info = context.services.model_manager.get_model(
|
text_encoder_info = context.services.model_manager.get_model(
|
||||||
**self.clip.text_encoder.dict(),
|
**self.clip.text_encoder.dict(),
|
||||||
)
|
)
|
||||||
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
|
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack:
|
||||||
loras = [
|
loras = [
|
||||||
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
||||||
for lora in self.clip.loras
|
for lora in self.clip.loras
|
||||||
@ -259,7 +258,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
||||||
|
|
||||||
with unet_info as unet, ExitStack() as stack:
|
with unet_info as unet: # , ExitStack() as stack:
|
||||||
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||||
loras = [
|
loras = [
|
||||||
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
||||||
|
@ -38,13 +38,10 @@ from easing_functions import (
|
|||||||
SineEaseInOut,
|
SineEaseInOut,
|
||||||
SineEaseOut,
|
SineEaseOut,
|
||||||
)
|
)
|
||||||
from matplotlib.figure import Figure
|
|
||||||
from matplotlib.ticker import MaxNLocator
|
from matplotlib.ticker import MaxNLocator
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import FloatCollectionOutput
|
from invokeai.app.invocations.primitives import FloatCollectionOutput
|
||||||
|
|
||||||
from ...backend.util.logging import InvokeAILogger
|
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from typing import Literal, Optional, Tuple, Union
|
from typing import Literal, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from anyio import Condition
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal, Union
|
from typing import Literal
|
||||||
|
|
||||||
import cv2 as cv
|
import cv2 as cv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -1,18 +1,14 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from typing import List, Union, Optional
|
from typing import Optional
|
||||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||||
from invokeai.app.services.board_record_storage import (
|
from invokeai.app.services.board_record_storage import (
|
||||||
BoardRecord,
|
BoardRecord,
|
||||||
BoardRecordStorageBase,
|
BoardRecordStorageBase,
|
||||||
)
|
)
|
||||||
|
|
||||||
from invokeai.app.services.image_record_storage import (
|
from invokeai.app.services.image_record_storage import ImageRecordStorageBase
|
||||||
ImageRecordStorageBase,
|
|
||||||
OffsetPaginatedResults,
|
|
||||||
)
|
|
||||||
from invokeai.app.services.models.board_record import BoardDTO
|
from invokeai.app.services.models.board_record import BoardDTO
|
||||||
from invokeai.app.services.models.image_record import ImageDTO, image_record_to_dto
|
|
||||||
from invokeai.app.services.urls import UrlServiceBase
|
from invokeai.app.services.urls import UrlServiceBase
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,15 +1,14 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Optional, cast
|
|
||||||
import sqlite3
|
|
||||||
import threading
|
import threading
|
||||||
from typing import Optional, Union
|
|
||||||
import uuid
|
import uuid
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional, Union, cast
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||||
from invokeai.app.services.models.board_record import (
|
from invokeai.app.services.models.board_record import (
|
||||||
BoardRecord,
|
BoardRecord,
|
||||||
deserialize_board_record,
|
deserialize_board_record,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, Extra
|
from pydantic import BaseModel, Field, Extra
|
||||||
|
|
||||||
|
|
||||||
@ -230,7 +229,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
|||||||
# Change the name of a board
|
# Change the name of a board
|
||||||
if changes.board_name is not None:
|
if changes.board_name is not None:
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""--sql
|
"""--sql
|
||||||
UPDATE boards
|
UPDATE boards
|
||||||
SET board_name = ?
|
SET board_name = ?
|
||||||
WHERE board_id = ?;
|
WHERE board_id = ?;
|
||||||
@ -241,7 +240,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
|||||||
# Change the cover image of a board
|
# Change the cover image of a board
|
||||||
if changes.cover_image_name is not None:
|
if changes.cover_image_name is not None:
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""--sql
|
"""--sql
|
||||||
UPDATE boards
|
UPDATE boards
|
||||||
SET cover_image_name = ?
|
SET cover_image_name = ?
|
||||||
WHERE board_id = ?;
|
WHERE board_id = ?;
|
||||||
|
@ -167,7 +167,7 @@ from argparse import ArgumentParser
|
|||||||
from omegaconf import OmegaConf, DictConfig, ListConfig
|
from omegaconf import OmegaConf, DictConfig, ListConfig
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pydantic import BaseSettings, Field, parse_obj_as
|
from pydantic import BaseSettings, Field, parse_obj_as
|
||||||
from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args
|
from typing import ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args
|
||||||
|
|
||||||
INIT_FILE = Path("invokeai.yaml")
|
INIT_FILE = Path("invokeai.yaml")
|
||||||
DB_FILE = Path("invokeai.db")
|
DB_FILE = Path("invokeai.db")
|
||||||
@ -394,7 +394,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
||||||
max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
|
max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
|
||||||
max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
|
max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
|
||||||
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='auto',description='Floating point precision', category='Memory/Performance')
|
precision : Literal['auto', 'float16', 'float32', 'autocast'] = Field(default='auto', description='Floating point precision', category='Memory/Performance')
|
||||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
|
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
|
||||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
||||||
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
|
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
|
||||||
@ -415,8 +415,8 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
|
|
||||||
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
|
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
|
||||||
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
||||||
log_format : Literal[tuple(['plain','color','syslog','legacy'])] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
|
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
|
||||||
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
|
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
|
||||||
|
|
||||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
|
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
@ -438,7 +438,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
if conf is None:
|
if conf is None:
|
||||||
try:
|
try:
|
||||||
conf = OmegaConf.load(self.root_dir / INIT_FILE)
|
conf = OmegaConf.load(self.root_dir / INIT_FILE)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
InvokeAISettings.initconf = conf
|
InvokeAISettings.initconf = conf
|
||||||
|
|
||||||
@ -457,7 +457,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
"""
|
"""
|
||||||
if (
|
if (
|
||||||
cls.singleton_config is None
|
cls.singleton_config is None
|
||||||
or type(cls.singleton_config) != cls
|
or type(cls.singleton_config) is not cls
|
||||||
or (kwargs and cls.singleton_init != kwargs)
|
or (kwargs and cls.singleton_init != kwargs)
|
||||||
):
|
):
|
||||||
cls.singleton_config = cls(**kwargs)
|
cls.singleton_config = cls(**kwargs)
|
||||||
|
@ -9,7 +9,8 @@ import networkx as nx
|
|||||||
from pydantic import BaseModel, root_validator, validator
|
from pydantic import BaseModel, root_validator, validator
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
|
|
||||||
from ..invocations import *
|
# Importing * is bad karma but needed here for node detection
|
||||||
|
from ..invocations import * # noqa: F401 F403
|
||||||
from ..invocations.baseinvocation import (
|
from ..invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
@ -445,7 +446,7 @@ class Graph(BaseModel):
|
|||||||
node = graph.nodes[node_id]
|
node = graph.nodes[node_id]
|
||||||
|
|
||||||
# Ensure the node type matches the new node
|
# Ensure the node type matches the new node
|
||||||
if type(node) != type(new_node):
|
if type(node) is not type(new_node):
|
||||||
raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}")
|
raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}")
|
||||||
|
|
||||||
# Ensure the new id is either the same or is not in the graph
|
# Ensure the new id is either the same or is not in the graph
|
||||||
@ -632,7 +633,7 @@ class Graph(BaseModel):
|
|||||||
[
|
[
|
||||||
t
|
t
|
||||||
for input_field in input_fields
|
for input_field in input_fields
|
||||||
for t in ([input_field] if get_origin(input_field) == None else get_args(input_field))
|
for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
|
||||||
if t != NoneType
|
if t != NoneType
|
||||||
]
|
]
|
||||||
) # Get unique types
|
) # Get unique types
|
||||||
@ -923,7 +924,7 @@ class GraphExecutionState(BaseModel):
|
|||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
if next_node_id == None:
|
if next_node_id is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Get all parents of the next node
|
# Get all parents of the next node
|
||||||
|
@ -179,7 +179,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
return None if image_name not in self.__cache else self.__cache[image_name]
|
return None if image_name not in self.__cache else self.__cache[image_name]
|
||||||
|
|
||||||
def __set_cache(self, image_name: Path, image: PILImageType):
|
def __set_cache(self, image_name: Path, image: PILImageType):
|
||||||
if not image_name in self.__cache:
|
if image_name not in self.__cache:
|
||||||
self.__cache[image_name] = image
|
self.__cache[image_name] = image
|
||||||
self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
|
self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
|
||||||
if len(self.__cache) > self.__max_cache_size:
|
if len(self.__cache) > self.__max_cache_size:
|
||||||
|
@ -282,7 +282,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
|
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""--sql
|
"""--sql
|
||||||
SELECT images.metadata FROM images
|
SELECT images.metadata FROM images
|
||||||
WHERE image_name = ?;
|
WHERE image_name = ?;
|
||||||
""",
|
""",
|
||||||
@ -309,7 +309,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
# Change the category of the image
|
# Change the category of the image
|
||||||
if changes.image_category is not None:
|
if changes.image_category is not None:
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""--sql
|
"""--sql
|
||||||
UPDATE images
|
UPDATE images
|
||||||
SET image_category = ?
|
SET image_category = ?
|
||||||
WHERE image_name = ?;
|
WHERE image_name = ?;
|
||||||
@ -320,7 +320,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
# Change the session associated with the image
|
# Change the session associated with the image
|
||||||
if changes.session_id is not None:
|
if changes.session_id is not None:
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""--sql
|
"""--sql
|
||||||
UPDATE images
|
UPDATE images
|
||||||
SET session_id = ?
|
SET session_id = ?
|
||||||
WHERE image_name = ?;
|
WHERE image_name = ?;
|
||||||
@ -331,7 +331,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
# Change the image's `is_intermediate`` flag
|
# Change the image's `is_intermediate`` flag
|
||||||
if changes.is_intermediate is not None:
|
if changes.is_intermediate is not None:
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""--sql
|
"""--sql
|
||||||
UPDATE images
|
UPDATE images
|
||||||
SET is_intermediate = ?
|
SET is_intermediate = ?
|
||||||
WHERE image_name = ?;
|
WHERE image_name = ?;
|
||||||
@ -342,7 +342,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
# Change the image's `starred`` state
|
# Change the image's `starred`` state
|
||||||
if changes.starred is not None:
|
if changes.starred is not None:
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""--sql
|
"""--sql
|
||||||
UPDATE images
|
UPDATE images
|
||||||
SET starred = ?
|
SET starred = ?
|
||||||
WHERE image_name = ?;
|
WHERE image_name = ?;
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import json
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
@ -379,10 +378,10 @@ class ImageService(ImageServiceABC):
|
|||||||
self._services.image_files.delete(image_name)
|
self._services.image_files.delete(image_name)
|
||||||
self._services.image_records.delete(image_name)
|
self._services.image_records.delete(image_name)
|
||||||
except ImageRecordDeleteException:
|
except ImageRecordDeleteException:
|
||||||
self._services.logger.error(f"Failed to delete image record")
|
self._services.logger.error("Failed to delete image record")
|
||||||
raise
|
raise
|
||||||
except ImageFileDeleteException:
|
except ImageFileDeleteException:
|
||||||
self._services.logger.error(f"Failed to delete image file")
|
self._services.logger.error("Failed to delete image file")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem deleting image record and file")
|
self._services.logger.error("Problem deleting image record and file")
|
||||||
@ -395,10 +394,10 @@ class ImageService(ImageServiceABC):
|
|||||||
self._services.image_files.delete(image_name)
|
self._services.image_files.delete(image_name)
|
||||||
self._services.image_records.delete_many(image_names)
|
self._services.image_records.delete_many(image_names)
|
||||||
except ImageRecordDeleteException:
|
except ImageRecordDeleteException:
|
||||||
self._services.logger.error(f"Failed to delete image records")
|
self._services.logger.error("Failed to delete image records")
|
||||||
raise
|
raise
|
||||||
except ImageFileDeleteException:
|
except ImageFileDeleteException:
|
||||||
self._services.logger.error(f"Failed to delete image files")
|
self._services.logger.error("Failed to delete image files")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem deleting image records and files")
|
self._services.logger.error("Problem deleting image records and files")
|
||||||
@ -412,10 +411,10 @@ class ImageService(ImageServiceABC):
|
|||||||
self._services.image_files.delete(image_name)
|
self._services.image_files.delete(image_name)
|
||||||
return count
|
return count
|
||||||
except ImageRecordDeleteException:
|
except ImageRecordDeleteException:
|
||||||
self._services.logger.error(f"Failed to delete image records")
|
self._services.logger.error("Failed to delete image records")
|
||||||
raise
|
raise
|
||||||
except ImageFileDeleteException:
|
except ImageFileDeleteException:
|
||||||
self._services.logger.error(f"Failed to delete image files")
|
self._services.logger.error("Failed to delete image files")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem deleting image records and files")
|
self._services.logger.error("Problem deleting image records and files")
|
||||||
|
@ -7,6 +7,7 @@ if TYPE_CHECKING:
|
|||||||
from invokeai.app.services.board_images import BoardImagesServiceABC
|
from invokeai.app.services.board_images import BoardImagesServiceABC
|
||||||
from invokeai.app.services.boards import BoardServiceABC
|
from invokeai.app.services.boards import BoardServiceABC
|
||||||
from invokeai.app.services.images import ImageServiceABC
|
from invokeai.app.services.images import ImageServiceABC
|
||||||
|
from invokeai.app.services.invocation_stats import InvocationStatsServiceBase
|
||||||
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
||||||
from invokeai.app.services.events import EventServiceBase
|
from invokeai.app.services.events import EventServiceBase
|
||||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
|
# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
|
||||||
"""Utility to collect execution time and GPU usage stats on invocations in flight"""
|
"""Utility to collect execution time and GPU usage stats on invocations in flight
|
||||||
|
|
||||||
"""
|
|
||||||
Usage:
|
Usage:
|
||||||
|
|
||||||
statistics = InvocationStatsService(graph_execution_manager)
|
statistics = InvocationStatsService(graph_execution_manager)
|
||||||
|
@ -60,7 +60,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
|
|||||||
return None if name not in self.__cache else self.__cache[name]
|
return None if name not in self.__cache else self.__cache[name]
|
||||||
|
|
||||||
def __set_cache(self, name: str, data: torch.Tensor):
|
def __set_cache(self, name: str, data: torch.Tensor):
|
||||||
if not name in self.__cache:
|
if name not in self.__cache:
|
||||||
self.__cache[name] = data
|
self.__cache[name] = data
|
||||||
self.__cache_ids.put(name)
|
self.__cache_ids.put(name)
|
||||||
if self.__cache_ids.qsize() > self.__max_cache_size:
|
if self.__cache_ids.qsize() > self.__max_cache_size:
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from typing import Union
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
@ -5,7 +6,7 @@ from PIL import Image
|
|||||||
from diffusers.utils import PIL_INTERPOLATION
|
from diffusers.utils import PIL_INTERPOLATION
|
||||||
|
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from controlnet_aux.util import HWC3, resize_image
|
from controlnet_aux.util import HWC3
|
||||||
|
|
||||||
###################################################################
|
###################################################################
|
||||||
# Copy of scripts/lvminthin.py from Mikubill/sd-webui-controlnet
|
# Copy of scripts/lvminthin.py from Mikubill/sd-webui-controlnet
|
||||||
@ -232,7 +233,8 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:
|
|||||||
k0 = float(h) / old_h
|
k0 = float(h) / old_h
|
||||||
k1 = float(w) / old_w
|
k1 = float(w) / old_w
|
||||||
|
|
||||||
safeint = lambda x: int(np.round(x))
|
def safeint(x: Union[int, float]) -> int:
|
||||||
|
return int(np.round(x))
|
||||||
|
|
||||||
# if resize_mode == external_code.ResizeMode.OUTER_FIT:
|
# if resize_mode == external_code.ResizeMode.OUTER_FIT:
|
||||||
if resize_mode == "fill_resize": # OUTER_FIT
|
if resize_mode == "fill_resize": # OUTER_FIT
|
||||||
|
@ -5,7 +5,6 @@ from invokeai.app.models.image import ProgressImage
|
|||||||
from ..invocations.baseinvocation import InvocationContext
|
from ..invocations.baseinvocation import InvocationContext
|
||||||
from ...backend.util.util import image_to_dataURL
|
from ...backend.util.util import image_to_dataURL
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
from ...backend.model_management.models import BaseModelType
|
from ...backend.model_management.models import BaseModelType
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Initialization file for invokeai.backend
|
Initialization file for invokeai.backend
|
||||||
"""
|
"""
|
||||||
from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubModelType, ModelInfo
|
from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubModelType, ModelInfo # noqa: F401
|
||||||
from .model_management.models import SilenceWarnings
|
from .model_management.models import SilenceWarnings # noqa: F401
|
||||||
|
@ -1,14 +1,16 @@
|
|||||||
"""
|
"""
|
||||||
Initialization file for invokeai.backend.image_util methods.
|
Initialization file for invokeai.backend.image_util methods.
|
||||||
"""
|
"""
|
||||||
from .patchmatch import PatchMatch
|
from .patchmatch import PatchMatch # noqa: F401
|
||||||
from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata
|
from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata # noqa: F401
|
||||||
from .seamless import configure_model_padding
|
from .seamless import configure_model_padding # noqa: F401
|
||||||
from .txt2mask import Txt2Mask
|
from .txt2mask import Txt2Mask # noqa: F401
|
||||||
from .util import InitImageResizer, make_grid
|
from .util import InitImageResizer, make_grid # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False):
|
def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False):
|
||||||
|
from PIL import ImageDraw
|
||||||
|
|
||||||
if not debug_status:
|
if not debug_status:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ class PngWriter:
|
|||||||
dirlist = sorted(os.listdir(self.outdir), reverse=True)
|
dirlist = sorted(os.listdir(self.outdir), reverse=True)
|
||||||
# find the first filename that matches our pattern or return 000000.0.png
|
# find the first filename that matches our pattern or return 000000.0.png
|
||||||
existing_name = next(
|
existing_name = next(
|
||||||
(f for f in dirlist if re.match("^(\d+)\..*\.png", f)),
|
(f for f in dirlist if re.match(r"^(\d+)\..*\.png", f)),
|
||||||
"0000000.0.png",
|
"0000000.0.png",
|
||||||
)
|
)
|
||||||
basecount = int(existing_name.split(".", 1)[0]) + 1
|
basecount = int(existing_name.split(".", 1)[0]) + 1
|
||||||
@ -98,11 +98,11 @@ class PromptFormatter:
|
|||||||
# to do: put model name into the t2i object
|
# to do: put model name into the t2i object
|
||||||
# switches.append(f'--model{t2i.model_name}')
|
# switches.append(f'--model{t2i.model_name}')
|
||||||
if opt.seamless or t2i.seamless:
|
if opt.seamless or t2i.seamless:
|
||||||
switches.append(f"--seamless")
|
switches.append("--seamless")
|
||||||
if opt.init_img:
|
if opt.init_img:
|
||||||
switches.append(f"-I{opt.init_img}")
|
switches.append(f"-I{opt.init_img}")
|
||||||
if opt.fit:
|
if opt.fit:
|
||||||
switches.append(f"--fit")
|
switches.append("--fit")
|
||||||
if opt.strength and opt.init_img is not None:
|
if opt.strength and opt.init_img is not None:
|
||||||
switches.append(f"-f{opt.strength or t2i.strength}")
|
switches.append(f"-f{opt.strength or t2i.strength}")
|
||||||
if opt.gfpgan_strength:
|
if opt.gfpgan_strength:
|
||||||
|
@ -52,7 +52,6 @@ from invokeai.frontend.install.widgets import (
|
|||||||
SingleSelectColumns,
|
SingleSelectColumns,
|
||||||
CenteredButtonPress,
|
CenteredButtonPress,
|
||||||
FileBox,
|
FileBox,
|
||||||
IntTitleSlider,
|
|
||||||
set_min_terminal_size,
|
set_min_terminal_size,
|
||||||
CyclingForm,
|
CyclingForm,
|
||||||
MIN_COLS,
|
MIN_COLS,
|
||||||
|
@ -116,7 +116,7 @@ class MigrateTo3(object):
|
|||||||
appropriate location within the destination models directory.
|
appropriate location within the destination models directory.
|
||||||
"""
|
"""
|
||||||
directories_scanned = set()
|
directories_scanned = set()
|
||||||
for root, dirs, files in os.walk(src_dir):
|
for root, dirs, files in os.walk(src_dir, followlinks=True):
|
||||||
for d in dirs:
|
for d in dirs:
|
||||||
try:
|
try:
|
||||||
model = Path(root, d)
|
model = Path(root, d)
|
||||||
@ -525,7 +525,7 @@ def do_migrate(src_directory: Path, dest_directory: Path):
|
|||||||
if version_3: # write into the dest directory
|
if version_3: # write into the dest directory
|
||||||
try:
|
try:
|
||||||
shutil.copy(dest_directory / "configs" / "models.yaml", config_file)
|
shutil.copy(dest_directory / "configs" / "models.yaml", config_file)
|
||||||
except:
|
except Exception:
|
||||||
MigrateTo3.initialize_yaml(config_file)
|
MigrateTo3.initialize_yaml(config_file)
|
||||||
mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory
|
mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory
|
||||||
(dest_directory / "models").replace(dest_models)
|
(dest_directory / "models").replace(dest_models)
|
||||||
|
@ -12,7 +12,6 @@ from typing import Optional, List, Dict, Callable, Union, Set
|
|||||||
import requests
|
import requests
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from diffusers import logging as dlogging
|
from diffusers import logging as dlogging
|
||||||
import onnx
|
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import hf_hub_url, HfFolder, HfApi
|
from huggingface_hub import hf_hub_url, HfFolder, HfApi
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
"""
|
"""
|
||||||
Initialization file for invokeai.backend.model_management
|
Initialization file for invokeai.backend.model_management
|
||||||
"""
|
"""
|
||||||
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
|
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType # noqa: F401
|
||||||
from .model_cache import ModelCache
|
from .model_cache import ModelCache # noqa: F401
|
||||||
from .lora import ModelPatcher, ONNXModelPatcher
|
from .lora import ModelPatcher, ONNXModelPatcher # noqa: F401
|
||||||
from .models import (
|
from .models import ( # noqa: F401
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
ModelType,
|
ModelType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
@ -12,5 +12,4 @@ from .models import (
|
|||||||
ModelNotFoundException,
|
ModelNotFoundException,
|
||||||
DuplicateModelException,
|
DuplicateModelException,
|
||||||
)
|
)
|
||||||
from .model_merge import ModelMerger, MergeInterpolationMethod
|
from .model_merge import ModelMerger, MergeInterpolationMethod # noqa: F401
|
||||||
from .lora import ModelPatcher
|
|
||||||
|
@ -5,21 +5,16 @@ from contextlib import contextmanager
|
|||||||
from typing import Optional, Dict, Tuple, Any, Union, List
|
from typing import Optional, Dict, Tuple, Any, Union, List
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
|
||||||
from safetensors.torch import load_file
|
|
||||||
from torch.utils.hooks import RemovableHandle
|
|
||||||
|
|
||||||
from diffusers.models import UNet2DConditionModel
|
|
||||||
from transformers import CLIPTextModel
|
|
||||||
from onnx import numpy_helper
|
|
||||||
from onnxruntime import OrtValue
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from compel.embeddings_provider import BaseTextualInversionManager
|
from compel.embeddings_provider import BaseTextualInversionManager
|
||||||
from diffusers.models import UNet2DConditionModel
|
from diffusers.models import UNet2DConditionModel
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
|
from .models.lora import LoRAModel
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
loras = [
|
loras = [
|
||||||
(lora_model1, 0.7),
|
(lora_model1, 0.7),
|
||||||
@ -52,7 +47,7 @@ class ModelPatcher:
|
|||||||
module = module.get_submodule(submodule_name)
|
module = module.get_submodule(submodule_name)
|
||||||
module_key += "." + submodule_name
|
module_key += "." + submodule_name
|
||||||
submodule_name = key_parts.pop(0)
|
submodule_name = key_parts.pop(0)
|
||||||
except:
|
except Exception:
|
||||||
submodule_name += "_" + key_parts.pop(0)
|
submodule_name += "_" + key_parts.pop(0)
|
||||||
|
|
||||||
module = module.get_submodule(submodule_name)
|
module = module.get_submodule(submodule_name)
|
||||||
@ -312,7 +307,8 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
|
|
||||||
|
|
||||||
class ONNXModelPatcher:
|
class ONNXModelPatcher:
|
||||||
from .models.base import IAIOnnxRuntimeModel, OnnxRuntimeModel
|
from .models.base import IAIOnnxRuntimeModel
|
||||||
|
from diffusers import OnnxRuntimeModel
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@contextmanager
|
@contextmanager
|
||||||
@ -341,7 +337,7 @@ class ONNXModelPatcher:
|
|||||||
def apply_lora(
|
def apply_lora(
|
||||||
cls,
|
cls,
|
||||||
model: IAIOnnxRuntimeModel,
|
model: IAIOnnxRuntimeModel,
|
||||||
loras: List[Tuple[LoraModel, float]],
|
loras: List[Tuple[LoRAModel, float]],
|
||||||
prefix: str,
|
prefix: str,
|
||||||
):
|
):
|
||||||
from .models.base import IAIOnnxRuntimeModel
|
from .models.base import IAIOnnxRuntimeModel
|
||||||
|
@ -273,7 +273,7 @@ class ModelCache(object):
|
|||||||
self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
|
self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
|
||||||
self.cache._print_cuda_stats()
|
self.cache._print_cuda_stats()
|
||||||
|
|
||||||
except:
|
except Exception:
|
||||||
self.cache_entry.unlock()
|
self.cache_entry.unlock()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@ -419,12 +419,12 @@ class ModelManager(object):
|
|||||||
base_model_str, model_type_str, model_name = model_key.split("/", 2)
|
base_model_str, model_type_str, model_name = model_key.split("/", 2)
|
||||||
try:
|
try:
|
||||||
model_type = ModelType(model_type_str)
|
model_type = ModelType(model_type_str)
|
||||||
except:
|
except Exception:
|
||||||
raise Exception(f"Unknown model type: {model_type_str}")
|
raise Exception(f"Unknown model type: {model_type_str}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
base_model = BaseModelType(base_model_str)
|
base_model = BaseModelType(base_model_str)
|
||||||
except:
|
except Exception:
|
||||||
raise Exception(f"Unknown base model: {base_model_str}")
|
raise Exception(f"Unknown base model: {base_model_str}")
|
||||||
|
|
||||||
return (model_name, base_model, model_type)
|
return (model_name, base_model, model_type)
|
||||||
@ -855,7 +855,7 @@ class ModelManager(object):
|
|||||||
info.pop("config")
|
info.pop("config")
|
||||||
|
|
||||||
result = self.add_model(model_name, base_model, model_type, model_attributes=info, clobber=True)
|
result = self.add_model(model_name, base_model, model_type, model_attributes=info, clobber=True)
|
||||||
except:
|
except Exception:
|
||||||
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
|
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
|
||||||
rmtree(new_diffusers_path)
|
rmtree(new_diffusers_path)
|
||||||
raise
|
raise
|
||||||
@ -1042,7 +1042,7 @@ class ModelManager(object):
|
|||||||
# Patch in the SD VAE from core so that it is available for use by the UI
|
# Patch in the SD VAE from core so that it is available for use by the UI
|
||||||
try:
|
try:
|
||||||
self.heuristic_import({str(self.resolve_model_path("core/convert/sd-vae-ft-mse"))})
|
self.heuristic_import({str(self.resolve_model_path("core/convert/sd-vae-ft-mse"))})
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
installer = ModelInstall(
|
installer = ModelInstall(
|
||||||
|
@ -431,7 +431,7 @@ class PipelineFolderProbe(FolderProbeBase):
|
|||||||
return ModelVariantType.Depth
|
return ModelVariantType.Depth
|
||||||
elif in_channels == 4:
|
elif in_channels == 4:
|
||||||
return ModelVariantType.Normal
|
return ModelVariantType.Normal
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return ModelVariantType.Normal
|
return ModelVariantType.Normal
|
||||||
|
|
||||||
|
@ -56,7 +56,7 @@ class ModelSearch(ABC):
|
|||||||
self.on_search_completed()
|
self.on_search_completed()
|
||||||
|
|
||||||
def walk_directory(self, path: Path):
|
def walk_directory(self, path: Path):
|
||||||
for root, dirs, files in os.walk(path):
|
for root, dirs, files in os.walk(path, followlinks=True):
|
||||||
if str(Path(root).name).startswith("."):
|
if str(Path(root).name).startswith("."):
|
||||||
self._pruned_paths.add(root)
|
self._pruned_paths.add(root)
|
||||||
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
|
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
|
||||||
|
@ -2,7 +2,7 @@ import inspect
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Literal, get_origin
|
from typing import Literal, get_origin
|
||||||
from .base import (
|
from .base import ( # noqa: F401
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
ModelType,
|
ModelType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
@ -118,7 +118,7 @@ def get_model_config_enums():
|
|||||||
fields = model_config.__annotations__
|
fields = model_config.__annotations__
|
||||||
try:
|
try:
|
||||||
field = fields["model_format"]
|
field = fields["model_format"]
|
||||||
except:
|
except Exception:
|
||||||
raise Exception("format field not found")
|
raise Exception("format field not found")
|
||||||
|
|
||||||
# model_format: None
|
# model_format: None
|
||||||
|
@ -3,27 +3,28 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import typing
|
import typing
|
||||||
import inspect
|
import inspect
|
||||||
from enum import Enum
|
import warnings
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
|
from contextlib import suppress
|
||||||
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import safetensors.torch
|
|
||||||
from pathlib import Path
|
|
||||||
from diffusers import DiffusionPipeline, ConfigMixin, OnnxRuntimeModel
|
|
||||||
|
|
||||||
from contextlib import suppress
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
|
||||||
|
|
||||||
import onnx
|
import onnx
|
||||||
|
import safetensors.torch
|
||||||
|
from diffusers import DiffusionPipeline, ConfigMixin
|
||||||
from onnx import numpy_helper
|
from onnx import numpy_helper
|
||||||
from onnxruntime import (
|
from onnxruntime import (
|
||||||
InferenceSession,
|
InferenceSession,
|
||||||
SessionOptions,
|
SessionOptions,
|
||||||
get_available_providers,
|
get_available_providers,
|
||||||
)
|
)
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
||||||
|
from diffusers import logging as diffusers_logging
|
||||||
|
from transformers import logging as transformers_logging
|
||||||
|
|
||||||
|
|
||||||
class DuplicateModelException(Exception):
|
class DuplicateModelException(Exception):
|
||||||
@ -171,7 +172,7 @@ class ModelBase(metaclass=ABCMeta):
|
|||||||
fields = value.__annotations__
|
fields = value.__annotations__
|
||||||
try:
|
try:
|
||||||
field = fields["model_format"]
|
field = fields["model_format"]
|
||||||
except:
|
except Exception:
|
||||||
raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})")
|
raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})")
|
||||||
|
|
||||||
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
|
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
|
||||||
@ -244,7 +245,7 @@ class DiffusersModel(ModelBase):
|
|||||||
try:
|
try:
|
||||||
config_data = DiffusionPipeline.load_config(self.model_path)
|
config_data = DiffusionPipeline.load_config(self.model_path)
|
||||||
# config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
|
# config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
|
||||||
except:
|
except Exception:
|
||||||
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
|
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
|
||||||
|
|
||||||
config_data.pop("_ignore_files", None)
|
config_data.pop("_ignore_files", None)
|
||||||
@ -343,7 +344,7 @@ def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, vari
|
|||||||
with open(os.path.join(model_path, file), "r") as f:
|
with open(os.path.join(model_path, file), "r") as f:
|
||||||
index_data = json.loads(f.read())
|
index_data = json.loads(f.read())
|
||||||
return int(index_data["metadata"]["total_size"])
|
return int(index_data["metadata"]["total_size"])
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# calculate files size if there is no index file
|
# calculate files size if there is no index file
|
||||||
@ -440,7 +441,7 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
|
|||||||
if str(path).endswith(".safetensors"):
|
if str(path).endswith(".safetensors"):
|
||||||
try:
|
try:
|
||||||
checkpoint = _fast_safetensors_reader(path)
|
checkpoint = _fast_safetensors_reader(path)
|
||||||
except:
|
except Exception:
|
||||||
# TODO: create issue for support "meta"?
|
# TODO: create issue for support "meta"?
|
||||||
checkpoint = safetensors.torch.load_file(path, device="cpu")
|
checkpoint = safetensors.torch.load_file(path, device="cpu")
|
||||||
else:
|
else:
|
||||||
@ -452,11 +453,6 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
|
|||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
|
|
||||||
import warnings
|
|
||||||
from diffusers import logging as diffusers_logging
|
|
||||||
from transformers import logging as transformers_logging
|
|
||||||
|
|
||||||
|
|
||||||
class SilenceWarnings(object):
|
class SilenceWarnings(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.transformers_verbosity = transformers_logging.get_verbosity()
|
self.transformers_verbosity = transformers_logging.get_verbosity()
|
||||||
@ -639,7 +635,7 @@ class IAIOnnxRuntimeModel:
|
|||||||
raise Exception("You should call create_session before running model")
|
raise Exception("You should call create_session before running model")
|
||||||
|
|
||||||
inputs = {k: np.array(v) for k, v in kwargs.items()}
|
inputs = {k: np.array(v) for k, v in kwargs.items()}
|
||||||
output_names = self.session.get_outputs()
|
# output_names = self.session.get_outputs()
|
||||||
# for k in inputs:
|
# for k in inputs:
|
||||||
# self.io_binding.bind_cpu_input(k, inputs[k])
|
# self.io_binding.bind_cpu_input(k, inputs[k])
|
||||||
# for name in output_names:
|
# for name in output_names:
|
||||||
|
@ -43,7 +43,7 @@ class ControlNetModel(ModelBase):
|
|||||||
try:
|
try:
|
||||||
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
|
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
|
||||||
# config = json.loads(os.path.join(self.model_path, "config.json"))
|
# config = json.loads(os.path.join(self.model_path, "config.json"))
|
||||||
except:
|
except Exception:
|
||||||
raise Exception("Invalid controlnet model! (config.json not found or invalid)")
|
raise Exception("Invalid controlnet model! (config.json not found or invalid)")
|
||||||
|
|
||||||
model_class_name = config.get("_class_name", None)
|
model_class_name = config.get("_class_name", None)
|
||||||
@ -53,7 +53,7 @@ class ControlNetModel(ModelBase):
|
|||||||
try:
|
try:
|
||||||
self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
|
self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
|
||||||
self.model_size = calc_model_size_by_fs(self.model_path)
|
self.model_size = calc_model_size_by_fs(self.model_path)
|
||||||
except:
|
except Exception:
|
||||||
raise Exception("Invalid ControlNet model!")
|
raise Exception("Invalid ControlNet model!")
|
||||||
|
|
||||||
def get_size(self, child_type: Optional[SubModelType] = None):
|
def get_size(self, child_type: Optional[SubModelType] = None):
|
||||||
@ -78,7 +78,7 @@ class ControlNetModel(ModelBase):
|
|||||||
variant=variant,
|
variant=variant,
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
if not model:
|
if not model:
|
||||||
raise ModelNotFoundException()
|
raise ModelNotFoundException()
|
||||||
|
@ -330,5 +330,5 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
|
|||||||
config_path = config_path.relative_to(app_config.root_path)
|
config_path = config_path.relative_to(app_config.root_path)
|
||||||
return str(config_path)
|
return str(config_path)
|
||||||
|
|
||||||
except:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
@ -1,25 +1,17 @@
|
|||||||
import os
|
|
||||||
import json
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pydantic import Field
|
from typing import Literal
|
||||||
from pathlib import Path
|
|
||||||
from typing import Literal, Optional, Union
|
from diffusers import OnnxRuntimeModel
|
||||||
from .base import (
|
from .base import (
|
||||||
ModelBase,
|
|
||||||
ModelConfigBase,
|
ModelConfigBase,
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
ModelType,
|
ModelType,
|
||||||
SubModelType,
|
|
||||||
ModelVariantType,
|
ModelVariantType,
|
||||||
DiffusersModel,
|
DiffusersModel,
|
||||||
SchedulerPredictionType,
|
SchedulerPredictionType,
|
||||||
SilenceWarnings,
|
|
||||||
read_checkpoint_meta,
|
|
||||||
classproperty,
|
classproperty,
|
||||||
OnnxRuntimeModel,
|
|
||||||
IAIOnnxRuntimeModel,
|
IAIOnnxRuntimeModel,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionOnnxModelFormat(str, Enum):
|
class StableDiffusionOnnxModelFormat(str, Enum):
|
||||||
|
@ -44,14 +44,14 @@ class VaeModel(ModelBase):
|
|||||||
try:
|
try:
|
||||||
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
|
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
|
||||||
# config = json.loads(os.path.join(self.model_path, "config.json"))
|
# config = json.loads(os.path.join(self.model_path, "config.json"))
|
||||||
except:
|
except Exception:
|
||||||
raise Exception("Invalid vae model! (config.json not found or invalid)")
|
raise Exception("Invalid vae model! (config.json not found or invalid)")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
vae_class_name = config.get("_class_name", "AutoencoderKL")
|
vae_class_name = config.get("_class_name", "AutoencoderKL")
|
||||||
self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name])
|
self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name])
|
||||||
self.model_size = calc_model_size_by_fs(self.model_path)
|
self.model_size = calc_model_size_by_fs(self.model_path)
|
||||||
except:
|
except Exception:
|
||||||
raise Exception("Invalid vae model! (Unkown vae type)")
|
raise Exception("Invalid vae model! (Unkown vae type)")
|
||||||
|
|
||||||
def get_size(self, child_type: Optional[SubModelType] = None):
|
def get_size(self, child_type: Optional[SubModelType] = None):
|
||||||
|
@ -1,11 +1,15 @@
|
|||||||
"""
|
"""
|
||||||
Initialization file for the invokeai.backend.stable_diffusion package
|
Initialization file for the invokeai.backend.stable_diffusion package
|
||||||
"""
|
"""
|
||||||
from .diffusers_pipeline import (
|
from .diffusers_pipeline import ( # noqa: F401
|
||||||
ConditioningData,
|
ConditioningData,
|
||||||
PipelineIntermediateState,
|
PipelineIntermediateState,
|
||||||
StableDiffusionGeneratorPipeline,
|
StableDiffusionGeneratorPipeline,
|
||||||
)
|
)
|
||||||
from .diffusion import InvokeAIDiffuserComponent
|
from .diffusion import InvokeAIDiffuserComponent # noqa: F401
|
||||||
from .diffusion.cross_attention_map_saving import AttentionMapSaver
|
from .diffusion.cross_attention_map_saving import AttentionMapSaver # noqa: F401
|
||||||
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings, BasicConditioningInfo, SDXLConditioningInfo
|
from .diffusion.shared_invokeai_diffusion import ( # noqa: F401
|
||||||
|
PostprocessingSettings,
|
||||||
|
BasicConditioningInfo,
|
||||||
|
SDXLConditioningInfo,
|
||||||
|
)
|
||||||
|
@ -2,10 +2,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import inspect
|
import inspect
|
||||||
import math
|
|
||||||
import secrets
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Callable, Generic, List, Optional, Type, Union
|
from typing import Any, Callable, List, Optional, Union
|
||||||
|
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
import einops
|
import einops
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
Initialization file for invokeai.models.diffusion
|
Initialization file for invokeai.models.diffusion
|
||||||
"""
|
"""
|
||||||
from .cross_attention_control import InvokeAICrossAttentionMixin
|
from .cross_attention_control import InvokeAICrossAttentionMixin # noqa: F401
|
||||||
from .cross_attention_map_saving import AttentionMapSaver
|
from .cross_attention_map_saving import AttentionMapSaver # noqa: F401
|
||||||
from .shared_invokeai_diffusion import (
|
from .shared_invokeai_diffusion import ( # noqa: F401
|
||||||
InvokeAIDiffuserComponent,
|
InvokeAIDiffuserComponent,
|
||||||
PostprocessingSettings,
|
PostprocessingSettings,
|
||||||
BasicConditioningInfo,
|
BasicConditioningInfo,
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
import enum
|
import enum
|
||||||
import math
|
import math
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import diffusers
|
import diffusers
|
||||||
@ -12,6 +13,11 @@ import torch
|
|||||||
from compel.cross_attention_control import Arguments
|
from compel.cross_attention_control import Arguments
|
||||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||||
from diffusers.models.attention_processor import AttentionProcessor
|
from diffusers.models.attention_processor import AttentionProcessor
|
||||||
|
from diffusers.models.attention_processor import (
|
||||||
|
Attention,
|
||||||
|
AttnProcessor,
|
||||||
|
SlicedAttnProcessor,
|
||||||
|
)
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
@ -522,14 +528,6 @@ class AttnProcessor:
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from dataclasses import dataclass, field
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from diffusers.models.attention_processor import (
|
|
||||||
Attention,
|
|
||||||
AttnProcessor,
|
|
||||||
SlicedAttnProcessor,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -5,8 +5,6 @@ import torch
|
|||||||
from torchvision.transforms.functional import InterpolationMode
|
from torchvision.transforms.functional import InterpolationMode
|
||||||
from torchvision.transforms.functional import resize as tv_resize
|
from torchvision.transforms.functional import resize as tv_resize
|
||||||
|
|
||||||
from .cross_attention_control import CrossAttentionType, get_cross_attention_modules
|
|
||||||
|
|
||||||
|
|
||||||
class AttentionMapSaver:
|
class AttentionMapSaver:
|
||||||
def __init__(self, token_ids: range, latents_shape: torch.Size):
|
def __init__(self, token_ids: range, latents_shape: torch.Size):
|
||||||
|
@ -3,15 +3,12 @@ from __future__ import annotations
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import math
|
import math
|
||||||
from typing import Any, Callable, Dict, Optional, Union, List
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import UNet2DConditionModel
|
from diffusers import UNet2DConditionModel
|
||||||
from diffusers.models.attention_processor import AttentionProcessor
|
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
from .cross_attention_control import (
|
from .cross_attention_control import (
|
||||||
@ -579,7 +576,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
latents.to(device="cpu")
|
latents.to(device="cpu")
|
||||||
|
|
||||||
if (
|
if (
|
||||||
h_symmetry_time_pct != None
|
h_symmetry_time_pct is not None
|
||||||
and self.last_percent_through < h_symmetry_time_pct
|
and self.last_percent_through < h_symmetry_time_pct
|
||||||
and percent_through >= h_symmetry_time_pct
|
and percent_through >= h_symmetry_time_pct
|
||||||
):
|
):
|
||||||
@ -595,7 +592,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
v_symmetry_time_pct != None
|
v_symmetry_time_pct is not None
|
||||||
and self.last_percent_through < v_symmetry_time_pct
|
and self.last_percent_through < v_symmetry_time_pct
|
||||||
and percent_through >= v_symmetry_time_pct
|
and percent_through >= v_symmetry_time_pct
|
||||||
):
|
):
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from ldm.modules.image_degradation.bsrgan import (
|
from ldm.modules.image_degradation.bsrgan import ( # noqa: F401
|
||||||
degradation_bsrgan_variant as degradation_fn_bsr,
|
degradation_bsrgan_variant as degradation_fn_bsr,
|
||||||
)
|
)
|
||||||
from ldm.modules.image_degradation.bsrgan_light import (
|
from ldm.modules.image_degradation.bsrgan_light import ( # noqa: F401
|
||||||
degradation_bsrgan_variant as degradation_fn_bsr_light,
|
degradation_bsrgan_variant as degradation_fn_bsr_light,
|
||||||
)
|
)
|
||||||
|
@ -573,14 +573,15 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
|
|||||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
||||||
"""
|
"""
|
||||||
image = util.uint2single(image)
|
image = util.uint2single(image)
|
||||||
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
|
jpeg_prob, scale2_prob = 0.9, 0.25
|
||||||
sf_ori = sf
|
# isp_prob = 0.25 # uncomment with `if i== 6` block below
|
||||||
|
# sf_ori = sf # uncomment with `if i== 6` block below
|
||||||
|
|
||||||
h1, w1 = image.shape[:2]
|
h1, w1 = image.shape[:2]
|
||||||
image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
|
image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
|
||||||
h, w = image.shape[:2]
|
h, w = image.shape[:2]
|
||||||
|
|
||||||
hq = image.copy()
|
# hq = image.copy() # uncomment with `if i== 6` block below
|
||||||
|
|
||||||
if sf == 4 and random.random() < scale2_prob: # downsample1
|
if sf == 4 and random.random() < scale2_prob: # downsample1
|
||||||
if np.random.rand() < 0.5:
|
if np.random.rand() < 0.5:
|
||||||
@ -777,7 +778,7 @@ if __name__ == "__main__":
|
|||||||
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
|
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
|
||||||
print(img_lq.shape)
|
print(img_lq.shape)
|
||||||
print("bicubic", img_lq_bicubic.shape)
|
print("bicubic", img_lq_bicubic.shape)
|
||||||
print(img_hq.shape)
|
# print(img_hq.shape)
|
||||||
lq_nearest = cv2.resize(
|
lq_nearest = cv2.resize(
|
||||||
util.single2uint(img_lq),
|
util.single2uint(img_lq),
|
||||||
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
||||||
@ -788,5 +789,6 @@ if __name__ == "__main__":
|
|||||||
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
||||||
interpolation=0,
|
interpolation=0,
|
||||||
)
|
)
|
||||||
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
|
# img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
|
||||||
|
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest], axis=1)
|
||||||
util.imsave(img_concat, str(i) + ".png")
|
util.imsave(img_concat, str(i) + ".png")
|
||||||
|
@ -577,14 +577,15 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
|
|||||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
||||||
"""
|
"""
|
||||||
image = util.uint2single(image)
|
image = util.uint2single(image)
|
||||||
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
|
jpeg_prob, scale2_prob = 0.9, 0.25
|
||||||
sf_ori = sf
|
# isp_prob = 0.25 # uncomment with `if i== 6` block below
|
||||||
|
# sf_ori = sf # uncomment with `if i== 6` block below
|
||||||
|
|
||||||
h1, w1 = image.shape[:2]
|
h1, w1 = image.shape[:2]
|
||||||
image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
|
image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
|
||||||
h, w = image.shape[:2]
|
h, w = image.shape[:2]
|
||||||
|
|
||||||
hq = image.copy()
|
# hq = image.copy() # uncomment with `if i== 6` block below
|
||||||
|
|
||||||
if sf == 4 and random.random() < scale2_prob: # downsample1
|
if sf == 4 and random.random() < scale2_prob: # downsample1
|
||||||
if np.random.rand() < 0.5:
|
if np.random.rand() < 0.5:
|
||||||
|
@ -8,8 +8,6 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from torchvision.utils import make_grid
|
from torchvision.utils import make_grid
|
||||||
|
|
||||||
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||||
@ -50,6 +48,8 @@ def get_timestamp():
|
|||||||
|
|
||||||
|
|
||||||
def imshow(x, title=None, cbar=False, figsize=None):
|
def imshow(x, title=None, cbar=False, figsize=None):
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
plt.figure(figsize=figsize)
|
plt.figure(figsize=figsize)
|
||||||
plt.imshow(np.squeeze(x), interpolation="nearest", cmap="gray")
|
plt.imshow(np.squeeze(x), interpolation="nearest", cmap="gray")
|
||||||
if title:
|
if title:
|
||||||
@ -60,6 +60,8 @@ def imshow(x, title=None, cbar=False, figsize=None):
|
|||||||
|
|
||||||
|
|
||||||
def surf(Z, cmap="rainbow", figsize=None):
|
def surf(Z, cmap="rainbow", figsize=None):
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
plt.figure(figsize=figsize)
|
plt.figure(figsize=figsize)
|
||||||
ax3 = plt.axes(projection="3d")
|
ax3 = plt.axes(projection="3d")
|
||||||
|
|
||||||
@ -89,7 +91,7 @@ def get_image_paths(dataroot):
|
|||||||
def _get_paths_from_images(path):
|
def _get_paths_from_images(path):
|
||||||
assert os.path.isdir(path), "{:s} is not a valid directory".format(path)
|
assert os.path.isdir(path), "{:s} is not a valid directory".format(path)
|
||||||
images = []
|
images = []
|
||||||
for dirpath, _, fnames in sorted(os.walk(path)):
|
for dirpath, _, fnames in sorted(os.walk(path, followlinks=True)):
|
||||||
for fname in sorted(fnames):
|
for fname in sorted(fnames):
|
||||||
if is_image_file(fname):
|
if is_image_file(fname):
|
||||||
img_path = os.path.join(dirpath, fname)
|
img_path = os.path.join(dirpath, fname)
|
||||||
|
@ -1 +1 @@
|
|||||||
from .schedulers import SCHEDULER_MAP
|
from .schedulers import SCHEDULER_MAP # noqa: F401
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
"""
|
"""
|
||||||
Initialization file for invokeai.backend.training
|
Initialization file for invokeai.backend.training
|
||||||
"""
|
"""
|
||||||
from .textual_inversion_training import do_textual_inversion_training, parse_args
|
from .textual_inversion_training import do_textual_inversion_training, parse_args # noqa: F401
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Initialization file for invokeai.backend.util
|
Initialization file for invokeai.backend.util
|
||||||
"""
|
"""
|
||||||
from .devices import (
|
from .devices import ( # noqa: F401
|
||||||
CPU_DEVICE,
|
CPU_DEVICE,
|
||||||
CUDA_DEVICE,
|
CUDA_DEVICE,
|
||||||
MPS_DEVICE,
|
MPS_DEVICE,
|
||||||
@ -10,5 +10,5 @@ from .devices import (
|
|||||||
normalize_device,
|
normalize_device,
|
||||||
torch_dtype,
|
torch_dtype,
|
||||||
)
|
)
|
||||||
from .log import write_log
|
from .log import write_log # noqa: F401
|
||||||
from .util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name, Chdir
|
from .util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name, Chdir # noqa: F401
|
||||||
|
@ -25,10 +25,15 @@ from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
|||||||
import diffusers
|
import diffusers
|
||||||
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
|
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
|
||||||
|
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
# TODO: create PR to diffusers
|
# TODO: create PR to diffusers
|
||||||
# Modified ControlNetModel with encoder_attention_mask argument added
|
# Modified ControlNetModel with encoder_attention_mask argument added
|
||||||
|
|
||||||
|
|
||||||
|
logger = InvokeAILogger.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||||
"""
|
"""
|
||||||
A ControlNet model.
|
A ControlNet model.
|
||||||
@ -111,7 +116,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
|||||||
"DownBlock2D",
|
"DownBlock2D",
|
||||||
),
|
),
|
||||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||||
layers_per_block: int = 2,
|
layers_per_block: int = 2,
|
||||||
downsample_padding: int = 1,
|
downsample_padding: int = 1,
|
||||||
mid_block_scale_factor: float = 1,
|
mid_block_scale_factor: float = 1,
|
||||||
|
@ -27,8 +27,8 @@ def write_log_message(results, output_cntr):
|
|||||||
log_lines = [f"{path}: {prompt}\n" for path, prompt in results]
|
log_lines = [f"{path}: {prompt}\n" for path, prompt in results]
|
||||||
if len(log_lines) > 1:
|
if len(log_lines) > 1:
|
||||||
subcntr = 1
|
subcntr = 1
|
||||||
for l in log_lines:
|
for ll in log_lines:
|
||||||
print(f"[{output_cntr}.{subcntr}] {l}", end="")
|
print(f"[{output_cntr}.{subcntr}] {ll}", end="")
|
||||||
subcntr += 1
|
subcntr += 1
|
||||||
else:
|
else:
|
||||||
print(f"[{output_cntr}] {log_lines[0]}", end="")
|
print(f"[{output_cntr}] {log_lines[0]}", end="")
|
||||||
|
@ -182,13 +182,13 @@ import urllib.parse
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig, get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import syslog
|
import syslog
|
||||||
|
|
||||||
SYSLOG_AVAILABLE = True
|
SYSLOG_AVAILABLE = True
|
||||||
except:
|
except ImportError:
|
||||||
SYSLOG_AVAILABLE = False
|
SYSLOG_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
@ -417,7 +417,7 @@ class InvokeAILogger(object):
|
|||||||
syslog_args["socktype"] = _SOCK_MAP[arg_value[0]]
|
syslog_args["socktype"] = _SOCK_MAP[arg_value[0]]
|
||||||
else:
|
else:
|
||||||
syslog_args["address"] = arg_name
|
syslog_args["address"] = arg_name
|
||||||
except:
|
except Exception:
|
||||||
raise ValueError(f"{args} is not a value argument list for syslog logging")
|
raise ValueError(f"{args} is not a value argument list for syslog logging")
|
||||||
return logging.handlers.SysLogHandler(**syslog_args)
|
return logging.handlers.SysLogHandler(**syslog_args)
|
||||||
|
|
||||||
|
@ -191,7 +191,7 @@ class ChunkedSlicedAttnProcessor:
|
|||||||
assert value.shape[0] == 1
|
assert value.shape[0] == 1
|
||||||
assert hidden_states.shape[0] == 1
|
assert hidden_states.shape[0] == 1
|
||||||
|
|
||||||
dtype = query.dtype
|
# dtype = query.dtype
|
||||||
if attn.upcast_attention:
|
if attn.upcast_attention:
|
||||||
query = query.float()
|
query = query.float()
|
||||||
key = key.float()
|
key = key.float()
|
||||||
|
@ -84,7 +84,7 @@ def count_params(model, verbose=False):
|
|||||||
|
|
||||||
|
|
||||||
def instantiate_from_config(config, **kwargs):
|
def instantiate_from_config(config, **kwargs):
|
||||||
if not "target" in config:
|
if "target" not in config:
|
||||||
if config == "__is_first_stage__":
|
if config == "__is_first_stage__":
|
||||||
return None
|
return None
|
||||||
elif config == "__is_unconditional__":
|
elif config == "__is_unconditional__":
|
||||||
@ -234,7 +234,8 @@ def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10
|
|||||||
.repeat_interleave(d[1], 1)
|
.repeat_interleave(d[1], 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
dot = lambda grad, shift: (
|
def dot(grad, shift):
|
||||||
|
return (
|
||||||
torch.stack(
|
torch.stack(
|
||||||
(
|
(
|
||||||
grid[: shape[0], : shape[1], 0] + shift[0],
|
grid[: shape[0], : shape[1], 0] + shift[0],
|
||||||
@ -287,7 +288,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
|
|||||||
if dest.is_dir():
|
if dest.is_dir():
|
||||||
try:
|
try:
|
||||||
file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)
|
file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)
|
||||||
except:
|
except AttributeError:
|
||||||
file_name = os.path.basename(url)
|
file_name = os.path.basename(url)
|
||||||
dest = dest / file_name
|
dest = dest / file_name
|
||||||
else:
|
else:
|
||||||
@ -342,7 +343,7 @@ def url_attachment_name(url: str) -> dict:
|
|||||||
resp = requests.get(url, stream=True)
|
resp = requests.get(url, stream=True)
|
||||||
match = re.search('filename="(.+)"', resp.headers.get("Content-Disposition"))
|
match = re.search('filename="(.+)"', resp.headers.get("Content-Disposition"))
|
||||||
return match.group(1)
|
return match.group(1)
|
||||||
except:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
"""
|
"""
|
||||||
Initialization file for invokeai.frontend.CLI
|
Initialization file for invokeai.frontend.CLI
|
||||||
"""
|
"""
|
||||||
from .CLI import main as invokeai_command_line_interface
|
from .CLI import main as invokeai_command_line_interface # noqa: F401
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
"""
|
"""
|
||||||
Wrapper for invokeai.backend.configure.invokeai_configure
|
Wrapper for invokeai.backend.configure.invokeai_configure
|
||||||
"""
|
"""
|
||||||
from ...backend.install.invokeai_configure import main as invokeai_configure
|
from ...backend.install.invokeai_configure import main as invokeai_configure # noqa: F401
|
||||||
|
@ -80,7 +80,7 @@ def welcome(versions: dict):
|
|||||||
def get_extras():
|
def get_extras():
|
||||||
extras = ""
|
extras = ""
|
||||||
try:
|
try:
|
||||||
dist = pkg_resources.get_distribution("xformers")
|
_ = pkg_resources.get_distribution("xformers")
|
||||||
extras = "[xformers]"
|
extras = "[xformers]"
|
||||||
except pkg_resources.DistributionNotFound:
|
except pkg_resources.DistributionNotFound:
|
||||||
pass
|
pass
|
||||||
@ -90,7 +90,7 @@ def get_extras():
|
|||||||
def main():
|
def main():
|
||||||
versions = get_versions()
|
versions = get_versions()
|
||||||
if invokeai_is_running():
|
if invokeai_is_running():
|
||||||
print(f":exclamation: [bold red]Please terminate all running instances of InvokeAI before updating.[/red bold]")
|
print(":exclamation: [bold red]Please terminate all running instances of InvokeAI before updating.[/red bold]")
|
||||||
input("Press any key to continue...")
|
input("Press any key to continue...")
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -122,9 +122,9 @@ def main():
|
|||||||
print("")
|
print("")
|
||||||
print("")
|
print("")
|
||||||
if os.system(cmd) == 0:
|
if os.system(cmd) == 0:
|
||||||
print(f":heavy_check_mark: Upgrade successful")
|
print(":heavy_check_mark: Upgrade successful")
|
||||||
else:
|
else:
|
||||||
print(f":exclamation: [bold red]Upgrade failed[/red bold]")
|
print(":exclamation: [bold red]Upgrade failed[/red bold]")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -251,7 +251,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
) -> dict[str, npyscreen.widget]:
|
) -> dict[str, npyscreen.widget]:
|
||||||
"""Generic code to create model selection widgets"""
|
"""Generic code to create model selection widgets"""
|
||||||
widgets = dict()
|
widgets = dict()
|
||||||
model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and not x in exclude]
|
model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude]
|
||||||
model_labels = [self.model_labels[x] for x in model_list]
|
model_labels = [self.model_labels[x] for x in model_list]
|
||||||
|
|
||||||
show_recommended = len(self.installed_models) == 0
|
show_recommended = len(self.installed_models) == 0
|
||||||
@ -357,14 +357,14 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
try:
|
try:
|
||||||
v.hidden = True
|
v.hidden = True
|
||||||
v.editable = False
|
v.editable = False
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
for k, v in widgets[selected_tab].items():
|
for k, v in widgets[selected_tab].items():
|
||||||
try:
|
try:
|
||||||
v.hidden = False
|
v.hidden = False
|
||||||
if not isinstance(v, (npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)):
|
if not isinstance(v, (npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)):
|
||||||
v.editable = True
|
v.editable = True
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
self.__class__.current_tab = selected_tab # for persistence
|
self.__class__.current_tab = selected_tab # for persistence
|
||||||
self.display()
|
self.display()
|
||||||
@ -541,7 +541,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
self.ti_models,
|
self.ti_models,
|
||||||
]
|
]
|
||||||
for section in ui_sections:
|
for section in ui_sections:
|
||||||
if not "models_selected" in section:
|
if "models_selected" not in section:
|
||||||
continue
|
continue
|
||||||
selected = set([section["models"][x] for x in section["models_selected"].value])
|
selected = set([section["models"][x] for x in section["models_selected"].value])
|
||||||
models_to_install = [x for x in selected if not self.all_models[x].installed]
|
models_to_install = [x for x in selected if not self.all_models[x].installed]
|
||||||
@ -637,7 +637,7 @@ def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection) -> SchedulerPre
|
|||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return response
|
return response
|
||||||
except:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@ -673,8 +673,7 @@ def process_and_execute(
|
|||||||
def select_and_download_models(opt: Namespace):
|
def select_and_download_models(opt: Namespace):
|
||||||
precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
|
precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
|
||||||
config.precision = precision
|
config.precision = precision
|
||||||
helper = lambda x: ask_user_for_prediction_type(x)
|
installer = ModelInstall(config, prediction_type_helper=ask_user_for_prediction_type)
|
||||||
installer = ModelInstall(config, prediction_type_helper=helper)
|
|
||||||
if opt.list_models:
|
if opt.list_models:
|
||||||
installer.list_models(opt.list_models)
|
installer.list_models(opt.list_models)
|
||||||
elif opt.add or opt.delete:
|
elif opt.add or opt.delete:
|
||||||
|
@ -102,8 +102,8 @@ def set_min_terminal_size(min_cols: int, min_lines: int) -> bool:
|
|||||||
class IntSlider(npyscreen.Slider):
|
class IntSlider(npyscreen.Slider):
|
||||||
def translate_value(self):
|
def translate_value(self):
|
||||||
stri = "%2d / %2d" % (self.value, self.out_of)
|
stri = "%2d / %2d" % (self.value, self.out_of)
|
||||||
l = (len(str(self.out_of))) * 2 + 4
|
length = (len(str(self.out_of))) * 2 + 4
|
||||||
stri = stri.rjust(l)
|
stri = stri.rjust(length)
|
||||||
return stri
|
return stri
|
||||||
|
|
||||||
|
|
||||||
@ -167,8 +167,8 @@ class FloatSlider(npyscreen.Slider):
|
|||||||
# this is supposed to adjust display precision, but doesn't
|
# this is supposed to adjust display precision, but doesn't
|
||||||
def translate_value(self):
|
def translate_value(self):
|
||||||
stri = "%3.2f / %3.2f" % (self.value, self.out_of)
|
stri = "%3.2f / %3.2f" % (self.value, self.out_of)
|
||||||
l = (len(str(self.out_of))) * 2 + 4
|
length = (len(str(self.out_of))) * 2 + 4
|
||||||
stri = stri.rjust(l)
|
stri = stri.rjust(length)
|
||||||
return stri
|
return stri
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import os
|
|
||||||
import sys
|
import sys
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
"""
|
"""
|
||||||
Initialization file for invokeai.frontend.merge
|
Initialization file for invokeai.frontend.merge
|
||||||
"""
|
"""
|
||||||
from .merge_diffusers import main as invokeai_merge_diffusers
|
from .merge_diffusers import main as invokeai_merge_diffusers # noqa: F401
|
||||||
|
@ -9,19 +9,15 @@ import curses
|
|||||||
import sys
|
import sys
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Union
|
from typing import List, Optional
|
||||||
|
|
||||||
import npyscreen
|
import npyscreen
|
||||||
from diffusers import DiffusionPipeline
|
|
||||||
from diffusers import logging as dlogging
|
|
||||||
from npyscreen import widget
|
from npyscreen import widget
|
||||||
from omegaconf import OmegaConf
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.model_management import (
|
from invokeai.backend.model_management import (
|
||||||
ModelMerger,
|
ModelMerger,
|
||||||
MergeInterpolationMethod,
|
|
||||||
ModelManager,
|
ModelManager,
|
||||||
ModelType,
|
ModelType,
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
@ -318,7 +314,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def get_model_names(self, base_model: BaseModelType = None) -> List[str]:
|
def get_model_names(self, base_model: Optional[BaseModelType] = None) -> List[str]:
|
||||||
model_names = [
|
model_names = [
|
||||||
info["model_name"]
|
info["model_name"]
|
||||||
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
|
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
"""
|
"""
|
||||||
Initialization file for invokeai.frontend.training
|
Initialization file for invokeai.frontend.training
|
||||||
"""
|
"""
|
||||||
from .textual_inversion import main as invokeai_textual_inversion
|
from .textual_inversion import main as invokeai_textual_inversion # noqa: F401
|
||||||
|
@ -59,7 +59,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
default = self.model_names.index(saved_args["model"])
|
default = self.model_names.index(saved_args["model"])
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
@ -377,7 +377,7 @@ def previous_args() -> dict:
|
|||||||
try:
|
try:
|
||||||
conf = OmegaConf.load(conf_file)
|
conf = OmegaConf.load(conf_file)
|
||||||
conf["placeholder_token"] = conf["placeholder_token"].strip("<>")
|
conf["placeholder_token"] = conf["placeholder_token"].strip("<>")
|
||||||
except:
|
except Exception:
|
||||||
conf = None
|
conf = None
|
||||||
|
|
||||||
return conf
|
return conf
|
||||||
|
@ -4,7 +4,7 @@ import { InvokeLogLevel } from 'app/logging/logger';
|
|||||||
import { userInvoked } from 'app/store/actions';
|
import { userInvoked } from 'app/store/actions';
|
||||||
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
|
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
|
||||||
import { t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
import { startCase } from 'lodash-es';
|
import { startCase, upperFirst } from 'lodash-es';
|
||||||
import { LogLevelName } from 'roarr';
|
import { LogLevelName } from 'roarr';
|
||||||
import {
|
import {
|
||||||
isAnySessionRejected,
|
isAnySessionRejected,
|
||||||
@ -26,6 +26,7 @@ import {
|
|||||||
import { ProgressImage } from 'services/events/types';
|
import { ProgressImage } from 'services/events/types';
|
||||||
import { makeToast } from '../util/makeToast';
|
import { makeToast } from '../util/makeToast';
|
||||||
import { LANGUAGES } from './constants';
|
import { LANGUAGES } from './constants';
|
||||||
|
import { zPydanticValidationError } from './zodSchemas';
|
||||||
|
|
||||||
export type CancelStrategy = 'immediate' | 'scheduled';
|
export type CancelStrategy = 'immediate' | 'scheduled';
|
||||||
|
|
||||||
@ -361,9 +362,24 @@ export const systemSlice = createSlice({
|
|||||||
state.progressImage = null;
|
state.progressImage = null;
|
||||||
|
|
||||||
let errorDescription = undefined;
|
let errorDescription = undefined;
|
||||||
|
const duration = 5000;
|
||||||
|
|
||||||
if (action.payload?.status === 422) {
|
if (action.payload?.status === 422) {
|
||||||
errorDescription = 'Validation Error';
|
const result = zPydanticValidationError.safeParse(action.payload);
|
||||||
|
if (result.success) {
|
||||||
|
result.data.error.detail.map((e) => {
|
||||||
|
state.toastQueue.push(
|
||||||
|
makeToast({
|
||||||
|
title: upperFirst(e.msg),
|
||||||
|
status: 'error',
|
||||||
|
description: `Path:
|
||||||
|
${e.loc.slice(3).join('.')}`,
|
||||||
|
duration,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
} else if (action.payload?.error) {
|
} else if (action.payload?.error) {
|
||||||
errorDescription = action.payload?.error as string;
|
errorDescription = action.payload?.error as string;
|
||||||
}
|
}
|
||||||
@ -373,6 +389,7 @@ export const systemSlice = createSlice({
|
|||||||
title: t('toast.serverError'),
|
title: t('toast.serverError'),
|
||||||
status: 'error',
|
status: 'error',
|
||||||
description: errorDescription,
|
description: errorDescription,
|
||||||
|
duration,
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
@ -0,0 +1,14 @@
|
|||||||
|
import { z } from 'zod';
|
||||||
|
|
||||||
|
export const zPydanticValidationError = z.object({
|
||||||
|
status: z.literal(422),
|
||||||
|
error: z.object({
|
||||||
|
detail: z.array(
|
||||||
|
z.object({
|
||||||
|
loc: z.array(z.string()),
|
||||||
|
msg: z.string(),
|
||||||
|
type: z.string(),
|
||||||
|
})
|
||||||
|
),
|
||||||
|
}),
|
||||||
|
});
|
@ -1,7 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
initialization file for invokeai
|
initialization file for invokeai
|
||||||
"""
|
"""
|
||||||
from .invokeai_version import __version__
|
from .invokeai_version import __version__ # noqa: F401
|
||||||
|
|
||||||
__app_id__ = "invoke-ai/InvokeAI"
|
__app_id__ = "invoke-ai/InvokeAI"
|
||||||
__app_name__ = "InvokeAI"
|
__app_name__ = "InvokeAI"
|
||||||
|
@ -8,9 +8,8 @@ from google.colab import files
|
|||||||
from IPython.display import Image as ipyimg
|
from IPython.display import Image as ipyimg
|
||||||
import ipywidgets as widgets
|
import ipywidgets as widgets
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from numpy import asarray
|
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
import torch, torchvision
|
import torchvision
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
from ldm.util import ismap
|
from ldm.util import ismap
|
||||||
import time
|
import time
|
||||||
@ -68,14 +67,14 @@ def get_custom_cond(mode):
|
|||||||
|
|
||||||
elif mode == "text_conditional":
|
elif mode == "text_conditional":
|
||||||
w = widgets.Text(value="A cake with cream!", disabled=True)
|
w = widgets.Text(value="A cake with cream!", disabled=True)
|
||||||
display(w)
|
display(w) # noqa: F821
|
||||||
|
|
||||||
with open(f"{dest}/{mode}/custom_{w.value[:20]}.txt", "w") as f:
|
with open(f"{dest}/{mode}/custom_{w.value[:20]}.txt", "w") as f:
|
||||||
f.write(w.value)
|
f.write(w.value)
|
||||||
|
|
||||||
elif mode == "class_conditional":
|
elif mode == "class_conditional":
|
||||||
w = widgets.IntSlider(min=0, max=1000)
|
w = widgets.IntSlider(min=0, max=1000)
|
||||||
display(w)
|
display(w) # noqa: F821
|
||||||
with open(f"{dest}/{mode}/custom.txt", "w") as f:
|
with open(f"{dest}/{mode}/custom.txt", "w") as f:
|
||||||
f.write(w.value)
|
f.write(w.value)
|
||||||
|
|
||||||
@ -96,7 +95,7 @@ def select_cond_path(mode):
|
|||||||
onlyfiles = [f for f in sorted(os.listdir(path))]
|
onlyfiles = [f for f in sorted(os.listdir(path))]
|
||||||
|
|
||||||
selected = widgets.RadioButtons(options=onlyfiles, description="Select conditioning:", disabled=False)
|
selected = widgets.RadioButtons(options=onlyfiles, description="Select conditioning:", disabled=False)
|
||||||
display(selected)
|
display(selected) # noqa: F821
|
||||||
selected_path = os.path.join(path, selected.value)
|
selected_path = os.path.join(path, selected.value)
|
||||||
return selected_path
|
return selected_path
|
||||||
|
|
||||||
@ -123,7 +122,7 @@ def get_cond(mode, selected_path):
|
|||||||
|
|
||||||
|
|
||||||
def visualize_cond_img(path):
|
def visualize_cond_img(path):
|
||||||
display(ipyimg(filename=path))
|
display(ipyimg(filename=path)) # noqa: F821
|
||||||
|
|
||||||
|
|
||||||
def run(model, selected_path, task, custom_steps, resize_enabled=False, classifier_ckpt=None, global_step=None):
|
def run(model, selected_path, task, custom_steps, resize_enabled=False, classifier_ckpt=None, global_step=None):
|
||||||
@ -331,7 +330,7 @@ def make_convolutional_sample(
|
|||||||
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
|
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
|
||||||
log["sample_noquant"] = x_sample_noquant
|
log["sample_noquant"] = x_sample_noquant
|
||||||
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
|
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
log["sample"] = x_sample
|
log["sample"] = x_sample
|
||||||
|
@ -95,7 +95,14 @@ dependencies = [
|
|||||||
"dev" = [
|
"dev" = [
|
||||||
"pudb",
|
"pudb",
|
||||||
]
|
]
|
||||||
"test" = ["pytest>6.0.0", "pytest-cov", "pytest-datadir", "black"]
|
"test" = [
|
||||||
|
"black",
|
||||||
|
"flake8",
|
||||||
|
"Flake8-pyproject",
|
||||||
|
"pytest>6.0.0",
|
||||||
|
"pytest-cov",
|
||||||
|
"pytest-datadir",
|
||||||
|
]
|
||||||
"xformers" = [
|
"xformers" = [
|
||||||
"xformers~=0.0.19; sys_platform!='darwin'",
|
"xformers~=0.0.19; sys_platform!='darwin'",
|
||||||
"triton; sys_platform=='linux'",
|
"triton; sys_platform=='linux'",
|
||||||
@ -185,6 +192,8 @@ output = "coverage/index.xml"
|
|||||||
|
|
||||||
[tool.flake8]
|
[tool.flake8]
|
||||||
max-line-length = 120
|
max-line-length = 120
|
||||||
|
ignore = ["E203", "E266", "E501", "W503"]
|
||||||
|
select = ["B", "C", "E", "F", "W", "T4"]
|
||||||
|
|
||||||
[tool.black]
|
[tool.black]
|
||||||
line-length = 120
|
line-length = 120
|
||||||
|
@ -4,7 +4,6 @@ Read a checkpoint/safetensors file and write out a template .json file containin
|
|||||||
its metadata for use in fast model probing.
|
its metadata for use in fast model probing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import sys
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
@ -3,11 +3,12 @@
|
|||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
from invokeai.app.cli_app import invoke_cli
|
||||||
|
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"dream.py is being deprecated, please run invoke.py for the " "new UI/API or legacy_api.py for the old API",
|
"dream.py is being deprecated, please run invoke.py for the " "new UI/API or legacy_api.py for the old API",
|
||||||
DeprecationWarning,
|
DeprecationWarning,
|
||||||
)
|
)
|
||||||
|
|
||||||
from invokeai.app.cli_app import invoke_cli
|
|
||||||
|
|
||||||
invoke_cli()
|
invoke_cli()
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
"""This script reads the "Invoke" Stable Diffusion prompt embedded in files generated by invoke.py"""
|
"""This script reads the "Invoke" Stable Diffusion prompt embedded in files generated by invoke.py"""
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image
|
||||||
|
|
||||||
if len(sys.argv) < 2:
|
if len(sys.argv) < 2:
|
||||||
print("Usage: file2prompt.py <file1.png> <file2.png> <file3.png>...")
|
print("Usage: file2prompt.py <file1.png> <file2.png> <file3.png>...")
|
||||||
|
@ -2,13 +2,11 @@
|
|||||||
|
|
||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage())
|
logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage())
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Change working directory to the repo root
|
# Change working directory to the repo root
|
||||||
|
@ -2,13 +2,11 @@
|
|||||||
|
|
||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage())
|
logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage())
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Change working directory to the repo root
|
# Change working directory to the repo root
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""make variations of input image"""
|
"""make variations of input image"""
|
||||||
|
|
||||||
import argparse, os, sys, glob
|
import argparse
|
||||||
|
import os
|
||||||
import PIL
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -12,7 +13,6 @@ from einops import rearrange, repeat
|
|||||||
from torchvision.utils import make_grid
|
from torchvision.utils import make_grid
|
||||||
from torch import autocast
|
from torch import autocast
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
import time
|
|
||||||
from pytorch_lightning import seed_everything
|
from pytorch_lightning import seed_everything
|
||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
@ -234,7 +234,6 @@ def main():
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with precision_scope(device.type):
|
with precision_scope(device.type):
|
||||||
with model.ema_scope():
|
with model.ema_scope():
|
||||||
tic = time.time()
|
|
||||||
all_samples = list()
|
all_samples = list()
|
||||||
for n in trange(opt.n_iter, desc="Sampling"):
|
for n in trange(opt.n_iter, desc="Sampling"):
|
||||||
for prompts in tqdm(data, desc="data"):
|
for prompts in tqdm(data, desc="data"):
|
||||||
@ -279,8 +278,6 @@ def main():
|
|||||||
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
|
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
|
||||||
grid_count += 1
|
grid_count += 1
|
||||||
|
|
||||||
toc = time.time()
|
|
||||||
|
|
||||||
print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.")
|
print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.")
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
import argparse, os, sys, glob
|
import argparse
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
import argparse, os, sys, glob
|
import argparse
|
||||||
import clip
|
import glob
|
||||||
|
import os
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange
|
||||||
from torchvision.utils import make_grid
|
from torchvision.utils import make_grid
|
||||||
import scann
|
import scann
|
||||||
import time
|
import time
|
||||||
@ -390,8 +390,8 @@ if __name__ == "__main__":
|
|||||||
grid = make_grid(grid, nrow=n_rows)
|
grid = make_grid(grid, nrow=n_rows)
|
||||||
|
|
||||||
# to image
|
# to image
|
||||||
grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
|
grid_np = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
|
||||||
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
|
Image.fromarray(grid_np.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
|
||||||
grid_count += 1
|
grid_count += 1
|
||||||
|
|
||||||
print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")
|
print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")
|
||||||
|
@ -1,24 +1,24 @@
|
|||||||
import argparse, os, sys, datetime, glob, importlib, csv
|
import argparse
|
||||||
|
import datetime
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import time
|
import time
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import torchvision
|
import torchvision
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from torch.utils.data import random_split, DataLoader, Dataset, Subset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from pytorch_lightning import seed_everything
|
from pytorch_lightning import seed_everything
|
||||||
from pytorch_lightning.trainer import Trainer
|
from pytorch_lightning.trainer import Trainer
|
||||||
from pytorch_lightning.callbacks import (
|
from pytorch_lightning.callbacks import Callback
|
||||||
ModelCheckpoint,
|
|
||||||
Callback,
|
|
||||||
LearningRateMonitor,
|
|
||||||
)
|
|
||||||
from pytorch_lightning.utilities.distributed import rank_zero_only
|
from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||||
from pytorch_lightning.utilities import rank_zero_info
|
from pytorch_lightning.utilities import rank_zero_info
|
||||||
|
|
||||||
@ -651,7 +651,7 @@ if __name__ == "__main__":
|
|||||||
trainer_config["accelerator"] = "auto"
|
trainer_config["accelerator"] = "auto"
|
||||||
for k in nondefault_trainer_args(opt):
|
for k in nondefault_trainer_args(opt):
|
||||||
trainer_config[k] = getattr(opt, k)
|
trainer_config[k] = getattr(opt, k)
|
||||||
if not "gpus" in trainer_config:
|
if "gpus" not in trainer_config:
|
||||||
del trainer_config["accelerator"]
|
del trainer_config["accelerator"]
|
||||||
cpu = True
|
cpu = True
|
||||||
else:
|
else:
|
||||||
@ -803,7 +803,7 @@ if __name__ == "__main__":
|
|||||||
trainer_opt.detect_anomaly = False
|
trainer_opt.detect_anomaly = False
|
||||||
|
|
||||||
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
|
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
|
||||||
trainer.logdir = logdir ###
|
trainer.logdir = logdir
|
||||||
|
|
||||||
# data
|
# data
|
||||||
config.data.params.train.params.data_root = opt.data_root
|
config.data.params.train.params.data_root = opt.data_root
|
||||||
|
@ -2,7 +2,7 @@ from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder
|
|||||||
from ldm.modules.embedding_manager import EmbeddingManager
|
from ldm.modules.embedding_manager import EmbeddingManager
|
||||||
from ldm.invoke.globals import Globals
|
from ldm.invoke.globals import Globals
|
||||||
|
|
||||||
import argparse, os
|
import argparse
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -108,7 +108,7 @@ if __name__ == "__main__":
|
|||||||
manager.load(manager_ckpt)
|
manager.load(manager_ckpt)
|
||||||
|
|
||||||
for placeholder_string in manager.string_to_token_dict:
|
for placeholder_string in manager.string_to_token_dict:
|
||||||
if not placeholder_string in string_to_token_dict:
|
if placeholder_string not in string_to_token_dict:
|
||||||
string_to_token_dict[placeholder_string] = manager.string_to_token_dict[placeholder_string]
|
string_to_token_dict[placeholder_string] = manager.string_to_token_dict[placeholder_string]
|
||||||
string_to_param_dict[placeholder_string] = manager.string_to_param_dict[placeholder_string]
|
string_to_param_dict[placeholder_string] = manager.string_to_param_dict[placeholder_string]
|
||||||
|
|
||||||
|
@ -1,6 +1,12 @@
|
|||||||
import argparse, os, sys, glob, datetime, yaml
|
import argparse
|
||||||
import torch
|
import datetime
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
|
||||||
@ -10,7 +16,9 @@ from PIL import Image
|
|||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
rescale = lambda x: (x + 1.0) / 2.0
|
|
||||||
|
def rescale(x: float) -> float:
|
||||||
|
return (x + 1.0) / 2.0
|
||||||
|
|
||||||
|
|
||||||
def custom_to_pil(x):
|
def custom_to_pil(x):
|
||||||
@ -45,7 +53,7 @@ def logs2pil(logs, keys=["sample"]):
|
|||||||
else:
|
else:
|
||||||
print(f"Unknown format for key {k}. ")
|
print(f"Unknown format for key {k}. ")
|
||||||
img = None
|
img = None
|
||||||
except:
|
except Exception:
|
||||||
img = None
|
img = None
|
||||||
imgs[k] = img
|
imgs[k] = img
|
||||||
return imgs
|
return imgs
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import os, sys
|
import os
|
||||||
|
import sys
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scann
|
import scann
|
||||||
import argparse
|
import argparse
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import argparse, os, sys, glob
|
import argparse
|
||||||
|
import os
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
@ -7,10 +8,9 @@ from tqdm import tqdm, trange
|
|||||||
from itertools import islice
|
from itertools import islice
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torchvision.utils import make_grid
|
from torchvision.utils import make_grid
|
||||||
import time
|
|
||||||
from pytorch_lightning import seed_everything
|
from pytorch_lightning import seed_everything
|
||||||
from torch import autocast
|
from torch import autocast
|
||||||
from contextlib import contextmanager, nullcontext
|
from contextlib import nullcontext
|
||||||
|
|
||||||
import k_diffusion as K
|
import k_diffusion as K
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -251,7 +251,6 @@ def main():
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with precision_scope(device.type):
|
with precision_scope(device.type):
|
||||||
with model.ema_scope():
|
with model.ema_scope():
|
||||||
tic = time.time()
|
|
||||||
all_samples = list()
|
all_samples = list()
|
||||||
for n in trange(opt.n_iter, desc="Sampling"):
|
for n in trange(opt.n_iter, desc="Sampling"):
|
||||||
for prompts in tqdm(data, desc="data"):
|
for prompts in tqdm(data, desc="data"):
|
||||||
@ -310,8 +309,6 @@ def main():
|
|||||||
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
|
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
|
||||||
grid_count += 1
|
grid_count += 1
|
||||||
|
|
||||||
toc = time.time()
|
|
||||||
|
|
||||||
print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.")
|
print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.")
|
||||||
|
|
||||||
|
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user