mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
9 Commits
next-test-
...
next-inpla
Author | SHA1 | Date | |
---|---|---|---|
6cb8f37a37 | |||
e5d9f33f7b | |||
5a87e7b3f8 | |||
f8b673dc85 | |||
cb8e0cbf35 | |||
33bd9da26c | |||
9190abd487 | |||
ff47334f22 | |||
a8c3efd98a |
@ -451,6 +451,7 @@ async def add_model_record(
|
||||
)
|
||||
async def install_model(
|
||||
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
|
||||
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
|
||||
# TODO(MM2): Can we type this?
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
@ -493,6 +494,7 @@ async def install_model(
|
||||
source=source,
|
||||
config=config,
|
||||
access_token=access_token,
|
||||
inplace=bool(inplace),
|
||||
)
|
||||
logger.info(f"Started installation of {source}")
|
||||
except UnknownModelException as e:
|
||||
|
@ -2,6 +2,7 @@
|
||||
# which are imported/used before parse_args() is called will get the default config values instead of the
|
||||
# values from the command line or config file.
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
@ -71,9 +72,25 @@ logger = InvokeAILogger.get_logger(config=app_config)
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Add startup event to load dependencies
|
||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
||||
yield
|
||||
# Shut down threads
|
||||
ApiDependencies.shutdown()
|
||||
|
||||
|
||||
# Create the app
|
||||
# TODO: create this all in a method so configuration/etc. can be passed in?
|
||||
app = FastAPI(title="Invoke - Community Edition", docs_url=None, redoc_url=None, separate_input_output_schemas=False)
|
||||
app = FastAPI(
|
||||
title="Invoke - Community Edition",
|
||||
docs_url=None,
|
||||
redoc_url=None,
|
||||
separate_input_output_schemas=False,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Add event handler
|
||||
event_handler_id: int = id(app)
|
||||
@ -96,18 +113,6 @@ app.add_middleware(
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
|
||||
|
||||
# Add startup event to load dependencies
|
||||
@app.on_event("startup")
|
||||
async def startup_event() -> None:
|
||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
||||
|
||||
|
||||
# Shut down threads
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event() -> None:
|
||||
ApiDependencies.shutdown()
|
||||
|
||||
|
||||
# Include all routers
|
||||
app.include_router(utilities.utilities_router, prefix="/api")
|
||||
app.include_router(model_manager.model_manager_router, prefix="/api")
|
||||
|
@ -1,17 +1,11 @@
|
||||
from typing import Iterator, List, Optional, Tuple, Union
|
||||
from typing import Iterator, List, Optional, Tuple, Union, cast
|
||||
|
||||
import torch
|
||||
from compel import Compel, ReturnedEmbeddingsType
|
||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
|
||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.ti_utils import generate_ti_list
|
||||
@ -25,12 +19,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
)
|
||||
from invokeai.backend.util.devices import torch_dtype
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from .model import ClipField
|
||||
|
||||
# unconditioned: Optional[torch.Tensor]
|
||||
@ -149,7 +138,7 @@ class SDXLPromptInvocationBase:
|
||||
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
|
||||
text_encoder_model = text_encoder_info.model
|
||||
assert isinstance(text_encoder_model, CLIPTextModel)
|
||||
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
|
||||
|
||||
# return zero on empty
|
||||
if prompt == "" and zero_on_empty:
|
||||
@ -196,7 +185,8 @@ class SDXLPromptInvocationBase:
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
|
||||
):
|
||||
assert isinstance(text_encoder, CLIPTextModel)
|
||||
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
|
||||
text_encoder = cast(CLIPTextModel, text_encoder)
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
|
@ -181,13 +181,14 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: bool = False,
|
||||
) -> ModelInstallJob:
|
||||
variants = "|".join(ModelRepoVariant.__members__.values())
|
||||
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
||||
source_obj: Optional[StringLikeSource] = None
|
||||
|
||||
if Path(source).exists(): # A local file or directory
|
||||
source_obj = LocalModelSource(path=Path(source))
|
||||
source_obj = LocalModelSource(path=Path(source), inplace=inplace)
|
||||
elif match := re.match(hf_repoid_re, source):
|
||||
source_obj = HFModelSource(
|
||||
repo_id=match.group(1),
|
||||
|
@ -28,6 +28,7 @@ from typing import Callable, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
default_logger: Logger = InvokeAILogger.get_logger()
|
||||
@ -117,13 +118,10 @@ class ModelSearch(ModelSearchBase):
|
||||
"""
|
||||
|
||||
models_found: Set[Path] = Field(default_factory=set)
|
||||
scanned_dirs: Set[Path] = Field(default_factory=set)
|
||||
pruned_paths: Set[Path] = Field(default_factory=set)
|
||||
config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
|
||||
|
||||
def search_started(self) -> None:
|
||||
self.models_found = set()
|
||||
self.scanned_dirs = set()
|
||||
self.pruned_paths = set()
|
||||
if self.on_search_started:
|
||||
self.on_search_started(self._directory)
|
||||
|
||||
@ -139,29 +137,28 @@ class ModelSearch(ModelSearchBase):
|
||||
|
||||
def search(self, directory: Union[Path, str]) -> Set[Path]:
|
||||
self._directory = Path(directory)
|
||||
if not self._directory.is_absolute():
|
||||
self._directory = self.config.models_path / self._directory
|
||||
self.stats = SearchStats() # zero out
|
||||
self.search_started() # This will initialize _models_found to empty
|
||||
self._walk_directory(directory)
|
||||
self._walk_directory(self._directory)
|
||||
self.search_completed()
|
||||
return self.models_found
|
||||
|
||||
def _walk_directory(self, path: Union[Path, str]) -> None:
|
||||
for root, dirs, files in os.walk(path, followlinks=True):
|
||||
# don't descend into directories that start with a "."
|
||||
# to avoid the Mac .DS_STORE issue.
|
||||
if str(Path(root).name).startswith("."):
|
||||
self.pruned_paths.add(Path(root))
|
||||
if any(Path(root).is_relative_to(x) for x in self.pruned_paths):
|
||||
continue
|
||||
|
||||
self.stats.items_scanned += len(dirs) + len(files)
|
||||
for d in dirs:
|
||||
path = Path(root) / d
|
||||
if path.parent in self.scanned_dirs:
|
||||
self.scanned_dirs.add(path)
|
||||
continue
|
||||
def _walk_directory(self, path: Union[Path, str], max_depth: int = 20) -> None:
|
||||
absolute_path = Path(path)
|
||||
if (
|
||||
len(absolute_path.parts) - len(self._directory.parts) > max_depth
|
||||
or not absolute_path.exists()
|
||||
or absolute_path.parent in self.models_found
|
||||
):
|
||||
return
|
||||
entries = os.scandir(absolute_path.as_posix())
|
||||
entries = [entry for entry in entries if not entry.name.startswith(".")]
|
||||
dirs = [entry for entry in entries if entry.is_dir()]
|
||||
file_names = [entry.name for entry in entries if entry.is_file()]
|
||||
if any(
|
||||
(path / x).exists()
|
||||
x in file_names
|
||||
for x in [
|
||||
"config.json",
|
||||
"model_index.json",
|
||||
@ -170,22 +167,23 @@ class ModelSearch(ModelSearchBase):
|
||||
"image_encoder.txt",
|
||||
]
|
||||
):
|
||||
self.scanned_dirs.add(path)
|
||||
try:
|
||||
self.model_found(path)
|
||||
self.model_found(absolute_path)
|
||||
return
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
return
|
||||
|
||||
for n in file_names:
|
||||
if n.endswith((".ckpt", ".bin", ".pth", ".safetensors", ".pt")):
|
||||
try:
|
||||
self.model_found(absolute_path / n)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
|
||||
for f in files:
|
||||
path = Path(root) / f
|
||||
if path.parent in self.scanned_dirs:
|
||||
continue
|
||||
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
|
||||
try:
|
||||
self.model_found(path)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
for d in dirs:
|
||||
self._walk_directory(absolute_path / d)
|
||||
|
@ -4,12 +4,12 @@ from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import OnnxRuntimeModel, UNet2DConditionModel
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.model_manager import AnyModel
|
||||
@ -168,7 +168,7 @@ class ModelPatcher:
|
||||
def apply_ti(
|
||||
cls,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: CLIPTextModel,
|
||||
text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
|
||||
ti_list: List[Tuple[str, TextualInversionModelRaw]],
|
||||
) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
|
||||
init_tokens_count = None
|
||||
@ -265,7 +265,7 @@ class ModelPatcher:
|
||||
@contextmanager
|
||||
def apply_clip_skip(
|
||||
cls,
|
||||
text_encoder: CLIPTextModel,
|
||||
text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
|
||||
clip_skip: int,
|
||||
) -> None:
|
||||
skipped_layers = []
|
||||
|
Reference in New Issue
Block a user