mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into build/gitignore
This commit is contained in:
commit
94740e440d
@ -9,6 +9,7 @@ from compel.prompt_parser import (Blend, Conjunction,
|
|||||||
FlattenedPrompt, Fragment)
|
FlattenedPrompt, Fragment)
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from ...backend.model_management.models import ModelNotFoundException
|
||||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||||
from ...backend.model_management.lora import ModelPatcher
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||||
@ -86,10 +87,10 @@ class CompelInvocation(BaseInvocation):
|
|||||||
model_type=ModelType.TextualInversion,
|
model_type=ModelType.TextualInversion,
|
||||||
).context.model
|
).context.model
|
||||||
)
|
)
|
||||||
except Exception:
|
except ModelNotFoundException:
|
||||||
# print(e)
|
# print(e)
|
||||||
#import traceback
|
#import traceback
|
||||||
# print(traceback.format_exc())
|
#print(traceback.format_exc())
|
||||||
print(f"Warn: trigger: \"{trigger}\" not found")
|
print(f"Warn: trigger: \"{trigger}\" not found")
|
||||||
|
|
||||||
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
|
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
|
||||||
|
@ -29,7 +29,7 @@ import invokeai.backend.util.logging as logger
|
|||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
from .model_manager import ModelManager
|
from .model_manager import ModelManager
|
||||||
from .model_cache import ModelCache
|
from picklescan.scanner import scan_file_path
|
||||||
from .models import BaseModelType, ModelVariantType
|
from .models import BaseModelType, ModelVariantType
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -1014,7 +1014,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
checkpoint = load_file(checkpoint_path)
|
checkpoint = load_file(checkpoint_path)
|
||||||
else:
|
else:
|
||||||
if scan_needed:
|
if scan_needed:
|
||||||
ModelCache.scan_model(checkpoint_path, checkpoint_path)
|
# scan model
|
||||||
|
scan_result = scan_file_path(checkpoint_path)
|
||||||
|
if scan_result.infected_files != 0:
|
||||||
|
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
|
||||||
checkpoint = torch.load(checkpoint_path)
|
checkpoint = torch.load(checkpoint_path)
|
||||||
|
|
||||||
# sometimes there is a state_dict key and sometimes not
|
# sometimes there is a state_dict key and sometimes not
|
||||||
|
@ -655,6 +655,9 @@ class TextualInversionModel:
|
|||||||
else:
|
else:
|
||||||
result.embedding = next(iter(state_dict.values()))
|
result.embedding = next(iter(state_dict.values()))
|
||||||
|
|
||||||
|
if len(result.embedding.shape) == 1:
|
||||||
|
result.embedding = result.embedding.unsqueeze(0)
|
||||||
|
|
||||||
if not isinstance(result.embedding, torch.Tensor):
|
if not isinstance(result.embedding, torch.Tensor):
|
||||||
raise ValueError(f"Invalid embeddings file: {file_path.name}")
|
raise ValueError(f"Invalid embeddings file: {file_path.name}")
|
||||||
|
|
||||||
|
@ -249,7 +249,7 @@ from .model_cache import ModelCache, ModelLocker
|
|||||||
from .models import (
|
from .models import (
|
||||||
BaseModelType, ModelType, SubModelType,
|
BaseModelType, ModelType, SubModelType,
|
||||||
ModelError, SchedulerPredictionType, MODEL_CLASSES,
|
ModelError, SchedulerPredictionType, MODEL_CLASSES,
|
||||||
ModelConfigBase,
|
ModelConfigBase, ModelNotFoundException,
|
||||||
)
|
)
|
||||||
|
|
||||||
# We are only starting to number the config file with release 3.
|
# We are only starting to number the config file with release 3.
|
||||||
@ -409,7 +409,7 @@ class ModelManager(object):
|
|||||||
if model_key not in self.models:
|
if model_key not in self.models:
|
||||||
self.scan_models_directory(base_model=base_model, model_type=model_type)
|
self.scan_models_directory(base_model=base_model, model_type=model_type)
|
||||||
if model_key not in self.models:
|
if model_key not in self.models:
|
||||||
raise Exception(f"Model not found - {model_key}")
|
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||||
|
|
||||||
model_config = self.models[model_key]
|
model_config = self.models[model_key]
|
||||||
model_path = self.app_config.root_path / model_config.path
|
model_path = self.app_config.root_path / model_config.path
|
||||||
@ -421,7 +421,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
self.models.pop(model_key, None)
|
self.models.pop(model_key, None)
|
||||||
raise Exception(f"Model not found - {model_key}")
|
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||||
|
|
||||||
# vae/movq override
|
# vae/movq override
|
||||||
# TODO:
|
# TODO:
|
||||||
|
@ -2,7 +2,7 @@ import inspect
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Literal, get_origin
|
from typing import Literal, get_origin
|
||||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
|
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException
|
||||||
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
||||||
from .vae import VaeModel
|
from .vae import VaeModel
|
||||||
from .lora import LoRAModel
|
from .lora import LoRAModel
|
||||||
|
@ -15,6 +15,9 @@ from contextlib import suppress
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
||||||
|
|
||||||
|
class ModelNotFoundException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
class BaseModelType(str, Enum):
|
class BaseModelType(str, Enum):
|
||||||
StableDiffusion1 = "sd-1"
|
StableDiffusion1 = "sd-1"
|
||||||
StableDiffusion2 = "sd-2"
|
StableDiffusion2 = "sd-2"
|
||||||
|
@ -8,6 +8,7 @@ from .base import (
|
|||||||
ModelType,
|
ModelType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
classproperty,
|
classproperty,
|
||||||
|
ModelNotFoundException,
|
||||||
)
|
)
|
||||||
# TODO: naming
|
# TODO: naming
|
||||||
from ..lora import TextualInversionModel as TextualInversionModelRaw
|
from ..lora import TextualInversionModel as TextualInversionModelRaw
|
||||||
@ -37,8 +38,15 @@ class TextualInversionModel(ModelBase):
|
|||||||
if child_type is not None:
|
if child_type is not None:
|
||||||
raise Exception("There is no child models in textual inversion")
|
raise Exception("There is no child models in textual inversion")
|
||||||
|
|
||||||
|
checkpoint_path = self.model_path
|
||||||
|
if os.path.isdir(checkpoint_path):
|
||||||
|
checkpoint_path = os.path.join(checkpoint_path, "learned_embeds.bin")
|
||||||
|
|
||||||
|
if not os.path.exists(checkpoint_path):
|
||||||
|
raise ModelNotFoundException()
|
||||||
|
|
||||||
model = TextualInversionModelRaw.from_checkpoint(
|
model = TextualInversionModelRaw.from_checkpoint(
|
||||||
file_path=self.model_path,
|
file_path=checkpoint_path,
|
||||||
dtype=torch_dtype,
|
dtype=torch_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -23,6 +23,7 @@ import { stateSelector } from 'app/store/store';
|
|||||||
import {
|
import {
|
||||||
imageDeletionConfirmed,
|
imageDeletionConfirmed,
|
||||||
imageToDeleteCleared,
|
imageToDeleteCleared,
|
||||||
|
isModalOpenChanged,
|
||||||
selectImageUsage,
|
selectImageUsage,
|
||||||
} from '../store/imageDeletionSlice';
|
} from '../store/imageDeletionSlice';
|
||||||
|
|
||||||
@ -63,6 +64,7 @@ const DeleteImageModal = () => {
|
|||||||
|
|
||||||
const handleClose = useCallback(() => {
|
const handleClose = useCallback(() => {
|
||||||
dispatch(imageToDeleteCleared());
|
dispatch(imageToDeleteCleared());
|
||||||
|
dispatch(isModalOpenChanged(false));
|
||||||
}, [dispatch]);
|
}, [dispatch]);
|
||||||
|
|
||||||
const handleDelete = useCallback(() => {
|
const handleDelete = useCallback(() => {
|
||||||
|
@ -31,6 +31,7 @@ const imageDeletion = createSlice({
|
|||||||
},
|
},
|
||||||
imageToDeleteCleared: (state) => {
|
imageToDeleteCleared: (state) => {
|
||||||
state.imageToDelete = null;
|
state.imageToDelete = null;
|
||||||
|
state.isModalOpen = false;
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
Loading…
Reference in New Issue
Block a user