Merge branch 'main' into ti-ui

This commit is contained in:
Lincoln Stein 2023-07-05 16:57:31 -04:00 committed by GitHub
commit 71dad6d404
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 45 additions and 25 deletions

2
.gitignore vendored
View File

@ -201,8 +201,6 @@ checkpoints
# If it's a Mac # If it's a Mac
.DS_Store .DS_Store
invokeai/frontend/web/dist/*
# Let the frontend manage its own gitignore # Let the frontend manage its own gitignore
!invokeai/frontend/web/* !invokeai/frontend/web/*

View File

@ -9,6 +9,7 @@ from compel.prompt_parser import (Blend, Conjunction,
FlattenedPrompt, Fragment) FlattenedPrompt, Fragment)
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ...backend.model_management.models import ModelNotFoundException
from ...backend.model_management import BaseModelType, ModelType, SubModelType from ...backend.model_management import BaseModelType, ModelType, SubModelType
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
@ -86,10 +87,10 @@ class CompelInvocation(BaseInvocation):
model_type=ModelType.TextualInversion, model_type=ModelType.TextualInversion,
).context.model ).context.model
) )
except Exception: except ModelNotFoundException:
# print(e) # print(e)
#import traceback #import traceback
# print(traceback.format_exc()) #print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found") print(f"Warn: trigger: \"{trigger}\" not found")
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\ with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\

View File

@ -228,6 +228,7 @@ class MigrateTo3(object):
self._migrate_pretrained(CLIPTextModel, self._migrate_pretrained(CLIPTextModel,
repo_id = repo_id, repo_id = repo_id,
dest = target_dir / 'clip-vit-large-patch14', dest = target_dir / 'clip-vit-large-patch14',
force = True,
**kwargs) **kwargs)
# sd-2 # sd-2
@ -291,21 +292,21 @@ class MigrateTo3(object):
def _model_probe_to_path(self, info: ModelProbeInfo)->Path: def _model_probe_to_path(self, info: ModelProbeInfo)->Path:
return Path(self.dest_models, info.base_type.value, info.model_type.value) return Path(self.dest_models, info.base_type.value, info.model_type.value)
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, **kwargs): def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force:bool=False, **kwargs):
if dest.exists(): if dest.exists() and not force:
logger.info(f'Skipping existing {dest}') logger.info(f'Skipping existing {dest}')
return return
model = model_class.from_pretrained(repo_id, **kwargs) model = model_class.from_pretrained(repo_id, **kwargs)
self._save_pretrained(model, dest) self._save_pretrained(model, dest, overwrite=force)
def _save_pretrained(self, model, dest: Path): def _save_pretrained(self, model, dest: Path, overwrite: bool=False):
if dest.exists():
logger.info(f'Skipping existing {dest}')
return
model_name = dest.name model_name = dest.name
download_path = dest.with_name(f'{model_name}.downloading') if overwrite:
model.save_pretrained(download_path, safe_serialization=True) model.save_pretrained(dest, safe_serialization=True)
download_path.replace(dest) else:
download_path = dest.with_name(f'{model_name}.downloading')
model.save_pretrained(download_path, safe_serialization=True)
download_path.replace(dest)
def _download_vae(self, repo_id: str, subfolder:str=None)->Path: def _download_vae(self, repo_id: str, subfolder:str=None)->Path:
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / 'models/hub', subfolder=subfolder) vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / 'models/hub', subfolder=subfolder)
@ -573,8 +574,10 @@ script, which will perform a full upgrade in place."""
dest_directory = args.dest_directory dest_directory = args.dest_directory
assert dest_directory.is_dir(), f"{dest_directory} is not a valid directory" assert dest_directory.is_dir(), f"{dest_directory} is not a valid directory"
assert (dest_directory / 'models').is_dir(), f"{dest_directory} does not contain a 'models' subdirectory"
assert (dest_directory / 'invokeai.yaml').exists(), f"{dest_directory} does not contain an InvokeAI init file." # TODO: revisit
# assert (dest_directory / 'models').is_dir(), f"{dest_directory} does not contain a 'models' subdirectory"
# assert (dest_directory / 'invokeai.yaml').exists(), f"{dest_directory} does not contain an InvokeAI init file."
do_migrate(root_directory,dest_directory) do_migrate(root_directory,dest_directory)

View File

@ -29,7 +29,7 @@ import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from .model_manager import ModelManager from .model_manager import ModelManager
from .model_cache import ModelCache from picklescan.scanner import scan_file_path
from .models import BaseModelType, ModelVariantType from .models import BaseModelType, ModelVariantType
try: try:
@ -1014,7 +1014,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint = load_file(checkpoint_path) checkpoint = load_file(checkpoint_path)
else: else:
if scan_needed: if scan_needed:
ModelCache.scan_model(checkpoint_path, checkpoint_path) # scan model
scan_result = scan_file_path(checkpoint_path)
if scan_result.infected_files != 0:
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
checkpoint = torch.load(checkpoint_path) checkpoint = torch.load(checkpoint_path)
# sometimes there is a state_dict key and sometimes not # sometimes there is a state_dict key and sometimes not

View File

@ -653,6 +653,9 @@ class TextualInversionModel:
else: else:
result.embedding = next(iter(state_dict.values())) result.embedding = next(iter(state_dict.values()))
if len(result.embedding.shape) == 1:
result.embedding = result.embedding.unsqueeze(0)
if not isinstance(result.embedding, torch.Tensor): if not isinstance(result.embedding, torch.Tensor):
raise ValueError(f"Invalid embeddings file: {file_path.name}") raise ValueError(f"Invalid embeddings file: {file_path.name}")

View File

@ -100,8 +100,6 @@ class ModelCache(object):
:param sha_chunksize: Chunksize to use when calculating sha256 model hash :param sha_chunksize: Chunksize to use when calculating sha256 model hash
''' '''
#max_cache_size = 9999 #max_cache_size = 9999
execution_device = torch.device('cuda')
self.model_infos: Dict[str, ModelBase] = dict() self.model_infos: Dict[str, ModelBase] = dict()
self.lazy_offloading = lazy_offloading self.lazy_offloading = lazy_offloading
#self.sequential_offload: bool=sequential_offload #self.sequential_offload: bool=sequential_offload

View File

@ -249,7 +249,7 @@ from .model_cache import ModelCache, ModelLocker
from .models import ( from .models import (
BaseModelType, ModelType, SubModelType, BaseModelType, ModelType, SubModelType,
ModelError, SchedulerPredictionType, MODEL_CLASSES, ModelError, SchedulerPredictionType, MODEL_CLASSES,
ModelConfigBase, ModelConfigBase, ModelNotFoundException,
) )
# We are only starting to number the config file with release 3. # We are only starting to number the config file with release 3.
@ -409,7 +409,7 @@ class ModelManager(object):
if model_key not in self.models: if model_key not in self.models:
self.scan_models_directory(base_model=base_model, model_type=model_type) self.scan_models_directory(base_model=base_model, model_type=model_type)
if model_key not in self.models: if model_key not in self.models:
raise Exception(f"Model not found - {model_key}") raise ModelNotFoundException(f"Model not found - {model_key}")
model_config = self.models[model_key] model_config = self.models[model_key]
model_path = self.app_config.root_path / model_config.path model_path = self.app_config.root_path / model_config.path
@ -421,7 +421,7 @@ class ModelManager(object):
else: else:
self.models.pop(model_key, None) self.models.pop(model_key, None)
raise Exception(f"Model not found - {model_key}") raise ModelNotFoundException(f"Model not found - {model_key}")
# vae/movq override # vae/movq override
# TODO: # TODO:

View File

@ -2,7 +2,7 @@ import inspect
from enum import Enum from enum import Enum
from pydantic import BaseModel from pydantic import BaseModel
from typing import Literal, get_origin from typing import Literal, get_origin
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .vae import VaeModel from .vae import VaeModel
from .lora import LoRAModel from .lora import LoRAModel

View File

@ -15,6 +15,9 @@ from contextlib import suppress
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
class ModelNotFoundException(Exception):
pass
class BaseModelType(str, Enum): class BaseModelType(str, Enum):
StableDiffusion1 = "sd-1" StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2" StableDiffusion2 = "sd-2"

View File

@ -8,6 +8,7 @@ from .base import (
ModelType, ModelType,
SubModelType, SubModelType,
classproperty, classproperty,
ModelNotFoundException,
) )
# TODO: naming # TODO: naming
from ..lora import TextualInversionModel as TextualInversionModelRaw from ..lora import TextualInversionModel as TextualInversionModelRaw
@ -37,8 +38,15 @@ class TextualInversionModel(ModelBase):
if child_type is not None: if child_type is not None:
raise Exception("There is no child models in textual inversion") raise Exception("There is no child models in textual inversion")
checkpoint_path = self.model_path
if os.path.isdir(checkpoint_path):
checkpoint_path = os.path.join(checkpoint_path, "learned_embeds.bin")
if not os.path.exists(checkpoint_path):
raise ModelNotFoundException()
model = TextualInversionModelRaw.from_checkpoint( model = TextualInversionModelRaw.from_checkpoint(
file_path=self.model_path, file_path=checkpoint_path,
dtype=torch_dtype, dtype=torch_dtype,
) )

View File

@ -23,6 +23,7 @@ import { stateSelector } from 'app/store/store';
import { import {
imageDeletionConfirmed, imageDeletionConfirmed,
imageToDeleteCleared, imageToDeleteCleared,
isModalOpenChanged,
selectImageUsage, selectImageUsage,
} from '../store/imageDeletionSlice'; } from '../store/imageDeletionSlice';
@ -63,6 +64,7 @@ const DeleteImageModal = () => {
const handleClose = useCallback(() => { const handleClose = useCallback(() => {
dispatch(imageToDeleteCleared()); dispatch(imageToDeleteCleared());
dispatch(isModalOpenChanged(false));
}, [dispatch]); }, [dispatch]);
const handleDelete = useCallback(() => { const handleDelete = useCallback(() => {

View File

@ -31,6 +31,7 @@ const imageDeletion = createSlice({
}, },
imageToDeleteCleared: (state) => { imageToDeleteCleared: (state) => {
state.imageToDelete = null; state.imageToDelete = null;
state.isModalOpen = false;
}, },
}, },
}); });