mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Enable v_prediction for sd-1 models (#4674)
## What type of PR is this? (check all applicable) - [X] Feature ## Have you discussed this change with the InvokeAI team? - [X] Yes ## Have you updated all relevant documentation? - [X] Yes ## Description It turns out that there are a few SD-1 models that use the `v_prediction` SchedulerPredictionType. Examples here: https://huggingface.co/zatochu/EasyFluff/tree/main . Previously we only allowed the user to set the prediction type for sd-2 models. This PR does three things: 1. Add a new checkpoint configuration file `v1-inference-v.yaml`. This will install automatically on new installs, but for existing installs users will need to update and then run `invokeai-configure` to get it. 2. Change the prompt on the web model install page to indicate that some SD-1 models use the "v_prediction" method 3. Provide backend support for sd-1 models that use the v_prediction method. ## Related Tickets & Documents <!-- For pull requests that relate or close an issue, please include them below. For example having the text: "closes #1234" would connect the current pull request to issue 1234. And when we merge the pull request, Github will automatically close the issue. --> - Related Issue # - Closes #4277 ## QA Instructions, Screenshots, Recordings Update, run `invoke-ai-configure --yes --skip-sd --skip-support`, and then use the web interface to install https://huggingface.co/zatochu/EasyFluff/resolve/main/EasyFluffV11.2.safetensors with the prediction type set to "v_prediction." Check that the installed model uses configuration `v1-inference-v.yaml`. If "None" is selected from the install menu, check that SD-1 models default to `v1-inference.yaml` and SD-2 default to `v2-inference-v.yaml`. Also try installing a checkpoint at a local path if a like-named config .yaml file is located next to it in the same directory. This should override everything else and use the local path .yaml. ## Added/updated tests? - [ ] Yes - [X] No
This commit is contained in:
commit
f05379f965
@ -146,7 +146,8 @@ async def update_model(
|
|||||||
async def import_model(
|
async def import_model(
|
||||||
location: str = Body(description="A model path, repo_id or URL to import"),
|
location: str = Body(description="A model path, repo_id or URL to import"),
|
||||||
prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body(
|
prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body(
|
||||||
description="Prediction type for SDv2 checkpoint files", default="v_prediction"
|
description="Prediction type for SDv2 checkpoints and rare SDv1 checkpoints",
|
||||||
|
default=None,
|
||||||
),
|
),
|
||||||
) -> ImportModelResponse:
|
) -> ImportModelResponse:
|
||||||
"""Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
|
"""Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
|
||||||
@ -155,6 +156,8 @@ async def import_model(
|
|||||||
prediction_types = {x.value: x for x in SchedulerPredictionType}
|
prediction_types = {x.value: x for x in SchedulerPredictionType}
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
|
print(f"DEBUG: prediction_type = {prediction_type}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||||
items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type)
|
items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type)
|
||||||
|
@ -47,8 +47,14 @@ Config_preamble = """
|
|||||||
|
|
||||||
LEGACY_CONFIGS = {
|
LEGACY_CONFIGS = {
|
||||||
BaseModelType.StableDiffusion1: {
|
BaseModelType.StableDiffusion1: {
|
||||||
ModelVariantType.Normal: "v1-inference.yaml",
|
ModelVariantType.Normal: {
|
||||||
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
|
SchedulerPredictionType.Epsilon: "v1-inference.yaml",
|
||||||
|
SchedulerPredictionType.VPrediction: "v1-inference-v.yaml",
|
||||||
|
},
|
||||||
|
ModelVariantType.Inpaint: {
|
||||||
|
SchedulerPredictionType.Epsilon: "v1-inpainting-inference.yaml",
|
||||||
|
SchedulerPredictionType.VPrediction: "v1-inpainting-inference-v.yaml",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
BaseModelType.StableDiffusion2: {
|
BaseModelType.StableDiffusion2: {
|
||||||
ModelVariantType.Normal: {
|
ModelVariantType.Normal: {
|
||||||
@ -286,7 +292,7 @@ class ModelInstall(object):
|
|||||||
location = download_with_resume(url, Path(staging))
|
location = download_with_resume(url, Path(staging))
|
||||||
if not location:
|
if not location:
|
||||||
logger.error(f"Unable to download {url}. Skipping.")
|
logger.error(f"Unable to download {url}. Skipping.")
|
||||||
info = ModelProbe().heuristic_probe(location)
|
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
|
||||||
dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
|
dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
|
||||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||||
models_path = shutil.move(location, dest)
|
models_path = shutil.move(location, dest)
|
||||||
@ -393,7 +399,7 @@ class ModelInstall(object):
|
|||||||
possible_conf = path.with_suffix(".yaml")
|
possible_conf = path.with_suffix(".yaml")
|
||||||
if possible_conf.exists():
|
if possible_conf.exists():
|
||||||
legacy_conf = str(self.relative_to_root(possible_conf))
|
legacy_conf = str(self.relative_to_root(possible_conf))
|
||||||
elif info.base_type == BaseModelType.StableDiffusion2:
|
elif info.base_type in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
|
||||||
legacy_conf = Path(
|
legacy_conf = Path(
|
||||||
self.config.legacy_conf_dir,
|
self.config.legacy_conf_dir,
|
||||||
LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type],
|
LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type],
|
||||||
|
@ -1279,12 +1279,12 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
extract_ema = original_config["model"]["params"]["use_ema"]
|
extract_ema = original_config["model"]["params"]["use_ema"]
|
||||||
|
|
||||||
if (
|
if (
|
||||||
model_version == BaseModelType.StableDiffusion2
|
model_version in [BaseModelType.StableDiffusion2, BaseModelType.StableDiffusion1]
|
||||||
and original_config["model"]["params"].get("parameterization") == "v"
|
and original_config["model"]["params"].get("parameterization") == "v"
|
||||||
):
|
):
|
||||||
prediction_type = "v_prediction"
|
prediction_type = "v_prediction"
|
||||||
upcast_attention = True
|
upcast_attention = True
|
||||||
image_size = 768
|
image_size = 768 if model_version == BaseModelType.StableDiffusion2 else 512
|
||||||
else:
|
else:
|
||||||
prediction_type = "epsilon"
|
prediction_type = "epsilon"
|
||||||
upcast_attention = False
|
upcast_attention = False
|
||||||
|
@ -90,8 +90,7 @@ class ModelProbe(object):
|
|||||||
to place it somewhere in the models directory hierarchy. If the model is
|
to place it somewhere in the models directory hierarchy. If the model is
|
||||||
already loaded into memory, you may provide it as model in order to avoid
|
already loaded into memory, you may provide it as model in order to avoid
|
||||||
opening it a second time. The prediction_type_helper callable is a function that receives
|
opening it a second time. The prediction_type_helper callable is a function that receives
|
||||||
the path to the model and returns the BaseModelType. It is called to distinguish
|
the path to the model and returns the SchedulerPredictionType.
|
||||||
between V2-Base and V2-768 SD models.
|
|
||||||
"""
|
"""
|
||||||
if model_path:
|
if model_path:
|
||||||
format_type = "diffusers" if model_path.is_dir() else "checkpoint"
|
format_type = "diffusers" if model_path.is_dir() else "checkpoint"
|
||||||
@ -305,25 +304,36 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
|
|||||||
else:
|
else:
|
||||||
raise InvalidModelException("Cannot determine base type")
|
raise InvalidModelException("Cannot determine base type")
|
||||||
|
|
||||||
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
|
||||||
|
"""Return model prediction type."""
|
||||||
|
# if there is a .yaml associated with this checkpoint, then we do not need
|
||||||
|
# to probe for the prediction type as it will be ignored.
|
||||||
|
if self.checkpoint_path and self.checkpoint_path.with_suffix(".yaml").exists():
|
||||||
|
return None
|
||||||
|
|
||||||
type = self.get_base_type()
|
type = self.get_base_type()
|
||||||
if type == BaseModelType.StableDiffusion1:
|
if type == BaseModelType.StableDiffusion2:
|
||||||
return SchedulerPredictionType.Epsilon
|
checkpoint = self.checkpoint
|
||||||
checkpoint = self.checkpoint
|
state_dict = self.checkpoint.get("state_dict") or checkpoint
|
||||||
state_dict = self.checkpoint.get("state_dict") or checkpoint
|
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
||||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
if "global_step" in checkpoint:
|
||||||
if "global_step" in checkpoint:
|
if checkpoint["global_step"] == 220000:
|
||||||
if checkpoint["global_step"] == 220000:
|
return SchedulerPredictionType.Epsilon
|
||||||
return SchedulerPredictionType.Epsilon
|
elif checkpoint["global_step"] == 110000:
|
||||||
elif checkpoint["global_step"] == 110000:
|
return SchedulerPredictionType.VPrediction
|
||||||
return SchedulerPredictionType.VPrediction
|
if self.helper and self.checkpoint_path:
|
||||||
if (
|
if helper_guess := self.helper(self.checkpoint_path):
|
||||||
self.checkpoint_path and self.helper and not self.checkpoint_path.with_suffix(".yaml").exists()
|
return helper_guess
|
||||||
): # if a .yaml config file exists, then this step not needed
|
return SchedulerPredictionType.VPrediction # a guess for sd2 ckpts
|
||||||
return self.helper(self.checkpoint_path)
|
|
||||||
else:
|
elif type == BaseModelType.StableDiffusion1:
|
||||||
return None
|
if self.helper and self.checkpoint_path:
|
||||||
|
if helper_guess := self.helper(self.checkpoint_path):
|
||||||
|
return helper_guess
|
||||||
|
return SchedulerPredictionType.Epsilon # a reasonable guess for sd1 ckpts
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class VaeCheckpointProbe(CheckpointProbeBase):
|
class VaeCheckpointProbe(CheckpointProbeBase):
|
||||||
|
80
invokeai/configs/stable-diffusion/v1-inference-v.yaml
Normal file
80
invokeai/configs/stable-diffusion/v1-inference-v.yaml
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
model:
|
||||||
|
base_learning_rate: 1.0e-04
|
||||||
|
target: invokeai.backend.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
parameterization: "v"
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: "jpg"
|
||||||
|
cond_stage_key: "txt"
|
||||||
|
image_size: 64
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: false # Note: different from the one we trained before
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_factor: 0.18215
|
||||||
|
use_ema: False
|
||||||
|
|
||||||
|
scheduler_config: # 10000 warmup steps
|
||||||
|
target: invokeai.backend.stable_diffusion.lr_scheduler.LambdaLinearScheduler
|
||||||
|
params:
|
||||||
|
warm_up_steps: [ 10000 ]
|
||||||
|
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||||
|
f_start: [ 1.e-6 ]
|
||||||
|
f_max: [ 1. ]
|
||||||
|
f_min: [ 1. ]
|
||||||
|
|
||||||
|
personalization_config:
|
||||||
|
target: invokeai.backend.stable_diffusion.embedding_manager.EmbeddingManager
|
||||||
|
params:
|
||||||
|
placeholder_strings: ["*"]
|
||||||
|
initializer_words: ['sculpture']
|
||||||
|
per_image_tokens: false
|
||||||
|
num_vectors_per_token: 1
|
||||||
|
progressive_words: False
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: invokeai.backend.stable_diffusion.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 32 # unused
|
||||||
|
in_channels: 4
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_heads: 8
|
||||||
|
use_spatial_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 768
|
||||||
|
use_checkpoint: True
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: invokeai.backend.stable_diffusion.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: invokeai.backend.stable_diffusion.encoders.modules.WeightedFrozenCLIPEmbedder
|
2
invokeai/frontend/web/dist/locales/en.json
vendored
2
invokeai/frontend/web/dist/locales/en.json
vendored
@ -574,7 +574,7 @@
|
|||||||
"onnxModels": "Onnx",
|
"onnxModels": "Onnx",
|
||||||
"pathToCustomConfig": "Path To Custom Config",
|
"pathToCustomConfig": "Path To Custom Config",
|
||||||
"pickModelType": "Pick Model Type",
|
"pickModelType": "Pick Model Type",
|
||||||
"predictionType": "Prediction Type (for Stable Diffusion 2.x Models only)",
|
"predictionType": "Prediction Type (for Stable Diffusion 2.x Models and occasional Stable Diffusion 1.x Models)",
|
||||||
"quickAdd": "Quick Add",
|
"quickAdd": "Quick Add",
|
||||||
"repo_id": "Repo ID",
|
"repo_id": "Repo ID",
|
||||||
"repoIDValidationMsg": "Online repository of your model",
|
"repoIDValidationMsg": "Online repository of your model",
|
||||||
|
@ -655,7 +655,7 @@
|
|||||||
"onnxModels": "Onnx",
|
"onnxModels": "Onnx",
|
||||||
"pathToCustomConfig": "Path To Custom Config",
|
"pathToCustomConfig": "Path To Custom Config",
|
||||||
"pickModelType": "Pick Model Type",
|
"pickModelType": "Pick Model Type",
|
||||||
"predictionType": "Prediction Type (for Stable Diffusion 2.x Models only)",
|
"predictionType": "Prediction Type (for Stable Diffusion 2.x Models and occasional Stable Diffusion 1.x Models)",
|
||||||
"quickAdd": "Quick Add",
|
"quickAdd": "Quick Add",
|
||||||
"repo_id": "Repo ID",
|
"repo_id": "Repo ID",
|
||||||
"repoIDValidationMsg": "Online repository of your model",
|
"repoIDValidationMsg": "Online repository of your model",
|
||||||
|
Loading…
Reference in New Issue
Block a user