mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
9 Commits
next-test-
...
next-inpla
Author | SHA1 | Date | |
---|---|---|---|
6cb8f37a37 | |||
e5d9f33f7b | |||
5a87e7b3f8 | |||
f8b673dc85 | |||
cb8e0cbf35 | |||
33bd9da26c | |||
9190abd487 | |||
ff47334f22 | |||
a8c3efd98a |
@ -451,6 +451,7 @@ async def add_model_record(
|
|||||||
)
|
)
|
||||||
async def install_model(
|
async def install_model(
|
||||||
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
|
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
|
||||||
|
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
|
||||||
# TODO(MM2): Can we type this?
|
# TODO(MM2): Can we type this?
|
||||||
config: Optional[Dict[str, Any]] = Body(
|
config: Optional[Dict[str, Any]] = Body(
|
||||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||||
@ -493,6 +494,7 @@ async def install_model(
|
|||||||
source=source,
|
source=source,
|
||||||
config=config,
|
config=config,
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
|
inplace=bool(inplace),
|
||||||
)
|
)
|
||||||
logger.info(f"Started installation of {source}")
|
logger.info(f"Started installation of {source}")
|
||||||
except UnknownModelException as e:
|
except UnknownModelException as e:
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
# which are imported/used before parse_args() is called will get the default config values instead of the
|
# which are imported/used before parse_args() is called will get the default config values instead of the
|
||||||
# values from the command line or config file.
|
# values from the command line or config file.
|
||||||
import sys
|
import sys
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||||
from invokeai.version.invokeai_version import __version__
|
from invokeai.version.invokeai_version import __version__
|
||||||
@ -71,9 +72,25 @@ logger = InvokeAILogger.get_logger(config=app_config)
|
|||||||
mimetypes.add_type("application/javascript", ".js")
|
mimetypes.add_type("application/javascript", ".js")
|
||||||
mimetypes.add_type("text/css", ".css")
|
mimetypes.add_type("text/css", ".css")
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
# Add startup event to load dependencies
|
||||||
|
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
||||||
|
yield
|
||||||
|
# Shut down threads
|
||||||
|
ApiDependencies.shutdown()
|
||||||
|
|
||||||
|
|
||||||
# Create the app
|
# Create the app
|
||||||
# TODO: create this all in a method so configuration/etc. can be passed in?
|
# TODO: create this all in a method so configuration/etc. can be passed in?
|
||||||
app = FastAPI(title="Invoke - Community Edition", docs_url=None, redoc_url=None, separate_input_output_schemas=False)
|
app = FastAPI(
|
||||||
|
title="Invoke - Community Edition",
|
||||||
|
docs_url=None,
|
||||||
|
redoc_url=None,
|
||||||
|
separate_input_output_schemas=False,
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
# Add event handler
|
# Add event handler
|
||||||
event_handler_id: int = id(app)
|
event_handler_id: int = id(app)
|
||||||
@ -96,18 +113,6 @@ app.add_middleware(
|
|||||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||||
|
|
||||||
|
|
||||||
# Add startup event to load dependencies
|
|
||||||
@app.on_event("startup")
|
|
||||||
async def startup_event() -> None:
|
|
||||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
|
||||||
|
|
||||||
|
|
||||||
# Shut down threads
|
|
||||||
@app.on_event("shutdown")
|
|
||||||
async def shutdown_event() -> None:
|
|
||||||
ApiDependencies.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
# Include all routers
|
# Include all routers
|
||||||
app.include_router(utilities.utilities_router, prefix="/api")
|
app.include_router(utilities.utilities_router, prefix="/api")
|
||||||
app.include_router(model_manager.model_manager_router, prefix="/api")
|
app.include_router(model_manager.model_manager_router, prefix="/api")
|
||||||
|
@ -1,17 +1,11 @@
|
|||||||
from typing import Iterator, List, Optional, Tuple, Union
|
from typing import Iterator, List, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from compel import Compel, ReturnedEmbeddingsType
|
from compel import Compel, ReturnedEmbeddingsType
|
||||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||||
|
|
||||||
from invokeai.app.invocations.fields import (
|
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
|
||||||
FieldDescriptions,
|
|
||||||
Input,
|
|
||||||
InputField,
|
|
||||||
OutputField,
|
|
||||||
UIComponent,
|
|
||||||
)
|
|
||||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.util.ti_utils import generate_ti_list
|
from invokeai.app.util.ti_utils import generate_ti_list
|
||||||
@ -25,12 +19,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
|||||||
)
|
)
|
||||||
from invokeai.backend.util.devices import torch_dtype
|
from invokeai.backend.util.devices import torch_dtype
|
||||||
|
|
||||||
from .baseinvocation import (
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||||
BaseInvocation,
|
|
||||||
BaseInvocationOutput,
|
|
||||||
invocation,
|
|
||||||
invocation_output,
|
|
||||||
)
|
|
||||||
from .model import ClipField
|
from .model import ClipField
|
||||||
|
|
||||||
# unconditioned: Optional[torch.Tensor]
|
# unconditioned: Optional[torch.Tensor]
|
||||||
@ -149,7 +138,7 @@ class SDXLPromptInvocationBase:
|
|||||||
assert isinstance(tokenizer_model, CLIPTokenizer)
|
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||||
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
|
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
|
||||||
text_encoder_model = text_encoder_info.model
|
text_encoder_model = text_encoder_info.model
|
||||||
assert isinstance(text_encoder_model, CLIPTextModel)
|
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
|
||||||
|
|
||||||
# return zero on empty
|
# return zero on empty
|
||||||
if prompt == "" and zero_on_empty:
|
if prompt == "" and zero_on_empty:
|
||||||
@ -196,7 +185,8 @@ class SDXLPromptInvocationBase:
|
|||||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||||
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
|
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
|
||||||
):
|
):
|
||||||
assert isinstance(text_encoder, CLIPTextModel)
|
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
|
||||||
|
text_encoder = cast(CLIPTextModel, text_encoder)
|
||||||
compel = Compel(
|
compel = Compel(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
|
@ -181,13 +181,14 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
source: str,
|
source: str,
|
||||||
config: Optional[Dict[str, Any]] = None,
|
config: Optional[Dict[str, Any]] = None,
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
|
inplace: bool = False,
|
||||||
) -> ModelInstallJob:
|
) -> ModelInstallJob:
|
||||||
variants = "|".join(ModelRepoVariant.__members__.values())
|
variants = "|".join(ModelRepoVariant.__members__.values())
|
||||||
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
||||||
source_obj: Optional[StringLikeSource] = None
|
source_obj: Optional[StringLikeSource] = None
|
||||||
|
|
||||||
if Path(source).exists(): # A local file or directory
|
if Path(source).exists(): # A local file or directory
|
||||||
source_obj = LocalModelSource(path=Path(source))
|
source_obj = LocalModelSource(path=Path(source), inplace=inplace)
|
||||||
elif match := re.match(hf_repoid_re, source):
|
elif match := re.match(hf_repoid_re, source):
|
||||||
source_obj = HFModelSource(
|
source_obj = HFModelSource(
|
||||||
repo_id=match.group(1),
|
repo_id=match.group(1),
|
||||||
|
@ -28,6 +28,7 @@ from typing import Callable, Optional, Set, Union
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
default_logger: Logger = InvokeAILogger.get_logger()
|
default_logger: Logger = InvokeAILogger.get_logger()
|
||||||
@ -117,13 +118,10 @@ class ModelSearch(ModelSearchBase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
models_found: Set[Path] = Field(default_factory=set)
|
models_found: Set[Path] = Field(default_factory=set)
|
||||||
scanned_dirs: Set[Path] = Field(default_factory=set)
|
config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
|
||||||
pruned_paths: Set[Path] = Field(default_factory=set)
|
|
||||||
|
|
||||||
def search_started(self) -> None:
|
def search_started(self) -> None:
|
||||||
self.models_found = set()
|
self.models_found = set()
|
||||||
self.scanned_dirs = set()
|
|
||||||
self.pruned_paths = set()
|
|
||||||
if self.on_search_started:
|
if self.on_search_started:
|
||||||
self.on_search_started(self._directory)
|
self.on_search_started(self._directory)
|
||||||
|
|
||||||
@ -139,53 +137,53 @@ class ModelSearch(ModelSearchBase):
|
|||||||
|
|
||||||
def search(self, directory: Union[Path, str]) -> Set[Path]:
|
def search(self, directory: Union[Path, str]) -> Set[Path]:
|
||||||
self._directory = Path(directory)
|
self._directory = Path(directory)
|
||||||
|
if not self._directory.is_absolute():
|
||||||
|
self._directory = self.config.models_path / self._directory
|
||||||
self.stats = SearchStats() # zero out
|
self.stats = SearchStats() # zero out
|
||||||
self.search_started() # This will initialize _models_found to empty
|
self.search_started() # This will initialize _models_found to empty
|
||||||
self._walk_directory(directory)
|
self._walk_directory(self._directory)
|
||||||
self.search_completed()
|
self.search_completed()
|
||||||
return self.models_found
|
return self.models_found
|
||||||
|
|
||||||
def _walk_directory(self, path: Union[Path, str]) -> None:
|
def _walk_directory(self, path: Union[Path, str], max_depth: int = 20) -> None:
|
||||||
for root, dirs, files in os.walk(path, followlinks=True):
|
absolute_path = Path(path)
|
||||||
# don't descend into directories that start with a "."
|
if (
|
||||||
# to avoid the Mac .DS_STORE issue.
|
len(absolute_path.parts) - len(self._directory.parts) > max_depth
|
||||||
if str(Path(root).name).startswith("."):
|
or not absolute_path.exists()
|
||||||
self.pruned_paths.add(Path(root))
|
or absolute_path.parent in self.models_found
|
||||||
if any(Path(root).is_relative_to(x) for x in self.pruned_paths):
|
):
|
||||||
continue
|
return
|
||||||
|
entries = os.scandir(absolute_path.as_posix())
|
||||||
|
entries = [entry for entry in entries if not entry.name.startswith(".")]
|
||||||
|
dirs = [entry for entry in entries if entry.is_dir()]
|
||||||
|
file_names = [entry.name for entry in entries if entry.is_file()]
|
||||||
|
if any(
|
||||||
|
x in file_names
|
||||||
|
for x in [
|
||||||
|
"config.json",
|
||||||
|
"model_index.json",
|
||||||
|
"learned_embeds.bin",
|
||||||
|
"pytorch_lora_weights.bin",
|
||||||
|
"image_encoder.txt",
|
||||||
|
]
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
self.model_found(absolute_path)
|
||||||
|
return
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(str(e))
|
||||||
|
return
|
||||||
|
|
||||||
self.stats.items_scanned += len(dirs) + len(files)
|
for n in file_names:
|
||||||
for d in dirs:
|
if n.endswith((".ckpt", ".bin", ".pth", ".safetensors", ".pt")):
|
||||||
path = Path(root) / d
|
try:
|
||||||
if path.parent in self.scanned_dirs:
|
self.model_found(absolute_path / n)
|
||||||
self.scanned_dirs.add(path)
|
except KeyboardInterrupt:
|
||||||
continue
|
raise
|
||||||
if any(
|
except Exception as e:
|
||||||
(path / x).exists()
|
self.logger.warning(str(e))
|
||||||
for x in [
|
|
||||||
"config.json",
|
|
||||||
"model_index.json",
|
|
||||||
"learned_embeds.bin",
|
|
||||||
"pytorch_lora_weights.bin",
|
|
||||||
"image_encoder.txt",
|
|
||||||
]
|
|
||||||
):
|
|
||||||
self.scanned_dirs.add(path)
|
|
||||||
try:
|
|
||||||
self.model_found(path)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.warning(str(e))
|
|
||||||
|
|
||||||
for f in files:
|
for d in dirs:
|
||||||
path = Path(root) / f
|
self._walk_directory(absolute_path / d)
|
||||||
if path.parent in self.scanned_dirs:
|
|
||||||
continue
|
|
||||||
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
|
|
||||||
try:
|
|
||||||
self.model_found(path)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.warning(str(e))
|
|
||||||
|
@ -4,12 +4,12 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import pickle
|
import pickle
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Dict, Iterator, List, Optional, Tuple
|
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from diffusers import OnnxRuntimeModel, UNet2DConditionModel
|
from diffusers import OnnxRuntimeModel, UNet2DConditionModel
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||||
|
|
||||||
from invokeai.app.shared.models import FreeUConfig
|
from invokeai.app.shared.models import FreeUConfig
|
||||||
from invokeai.backend.model_manager import AnyModel
|
from invokeai.backend.model_manager import AnyModel
|
||||||
@ -168,7 +168,7 @@ class ModelPatcher:
|
|||||||
def apply_ti(
|
def apply_ti(
|
||||||
cls,
|
cls,
|
||||||
tokenizer: CLIPTokenizer,
|
tokenizer: CLIPTokenizer,
|
||||||
text_encoder: CLIPTextModel,
|
text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
|
||||||
ti_list: List[Tuple[str, TextualInversionModelRaw]],
|
ti_list: List[Tuple[str, TextualInversionModelRaw]],
|
||||||
) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
|
) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
|
||||||
init_tokens_count = None
|
init_tokens_count = None
|
||||||
@ -265,7 +265,7 @@ class ModelPatcher:
|
|||||||
@contextmanager
|
@contextmanager
|
||||||
def apply_clip_skip(
|
def apply_clip_skip(
|
||||||
cls,
|
cls,
|
||||||
text_encoder: CLIPTextModel,
|
text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
|
||||||
clip_skip: int,
|
clip_skip: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
skipped_layers = []
|
skipped_layers = []
|
||||||
|
Reference in New Issue
Block a user