Enable v_prediction for sd-1 models (#4674)

## What type of PR is this? (check all applicable)

- [X] Feature

## Have you discussed this change with the InvokeAI team?
- [X] Yes

      
## Have you updated all relevant documentation?
- [X] Yes

## Description

It turns out that there are a few SD-1 models that use the
`v_prediction` SchedulerPredictionType. Examples here:
https://huggingface.co/zatochu/EasyFluff/tree/main . Previously we only
allowed the user to set the prediction type for sd-2 models. This PR
does three things:

1. Add a new checkpoint configuration file `v1-inference-v.yaml`. This
will install automatically on new installs, but for existing installs
users will need to update and then run `invokeai-configure` to get it.
2. Change the prompt on the web model install page to indicate that some
SD-1 models use the "v_prediction" method
3. Provide backend support for sd-1 models that use the v_prediction
method.

## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #4277 

## QA Instructions, Screenshots, Recordings

Update, run `invoke-ai-configure --yes --skip-sd --skip-support`, and
then use the web interface to install
https://huggingface.co/zatochu/EasyFluff/resolve/main/EasyFluffV11.2.safetensors
with the prediction type set to "v_prediction." Check that the installed
model uses configuration `v1-inference-v.yaml`.

If "None" is selected from the install menu, check that SD-1 models
default to `v1-inference.yaml` and SD-2 default to
`v2-inference-v.yaml`.

Also try installing a checkpoint at a local path if a like-named config
.yaml file is located next to it in the same directory. This should
override everything else and use the local path .yaml.

## Added/updated tests?

- [ ] Yes
- [X] No
This commit is contained in:
Lincoln Stein 2023-09-24 15:24:36 -04:00 committed by GitHub
commit f05379f965
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 128 additions and 29 deletions

View File

@ -146,7 +146,8 @@ async def update_model(
async def import_model( async def import_model(
location: str = Body(description="A model path, repo_id or URL to import"), location: str = Body(description="A model path, repo_id or URL to import"),
prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body( prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body(
description="Prediction type for SDv2 checkpoint files", default="v_prediction" description="Prediction type for SDv2 checkpoints and rare SDv1 checkpoints",
default=None,
), ),
) -> ImportModelResponse: ) -> ImportModelResponse:
"""Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically""" """Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
@ -155,6 +156,8 @@ async def import_model(
prediction_types = {x.value: x for x in SchedulerPredictionType} prediction_types = {x.value: x for x in SchedulerPredictionType}
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
print(f"DEBUG: prediction_type = {prediction_type}")
try: try:
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import( installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type) items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type)

View File

@ -47,8 +47,14 @@ Config_preamble = """
LEGACY_CONFIGS = { LEGACY_CONFIGS = {
BaseModelType.StableDiffusion1: { BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: "v1-inference.yaml", ModelVariantType.Normal: {
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml", SchedulerPredictionType.Epsilon: "v1-inference.yaml",
SchedulerPredictionType.VPrediction: "v1-inference-v.yaml",
},
ModelVariantType.Inpaint: {
SchedulerPredictionType.Epsilon: "v1-inpainting-inference.yaml",
SchedulerPredictionType.VPrediction: "v1-inpainting-inference-v.yaml",
},
}, },
BaseModelType.StableDiffusion2: { BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: { ModelVariantType.Normal: {
@ -286,7 +292,7 @@ class ModelInstall(object):
location = download_with_resume(url, Path(staging)) location = download_with_resume(url, Path(staging))
if not location: if not location:
logger.error(f"Unable to download {url}. Skipping.") logger.error(f"Unable to download {url}. Skipping.")
info = ModelProbe().heuristic_probe(location) info = ModelProbe().heuristic_probe(location, self.prediction_helper)
dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
dest.parent.mkdir(parents=True, exist_ok=True) dest.parent.mkdir(parents=True, exist_ok=True)
models_path = shutil.move(location, dest) models_path = shutil.move(location, dest)
@ -393,7 +399,7 @@ class ModelInstall(object):
possible_conf = path.with_suffix(".yaml") possible_conf = path.with_suffix(".yaml")
if possible_conf.exists(): if possible_conf.exists():
legacy_conf = str(self.relative_to_root(possible_conf)) legacy_conf = str(self.relative_to_root(possible_conf))
elif info.base_type == BaseModelType.StableDiffusion2: elif info.base_type in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
legacy_conf = Path( legacy_conf = Path(
self.config.legacy_conf_dir, self.config.legacy_conf_dir,
LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type], LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type],

View File

@ -1279,12 +1279,12 @@ def download_from_original_stable_diffusion_ckpt(
extract_ema = original_config["model"]["params"]["use_ema"] extract_ema = original_config["model"]["params"]["use_ema"]
if ( if (
model_version == BaseModelType.StableDiffusion2 model_version in [BaseModelType.StableDiffusion2, BaseModelType.StableDiffusion1]
and original_config["model"]["params"].get("parameterization") == "v" and original_config["model"]["params"].get("parameterization") == "v"
): ):
prediction_type = "v_prediction" prediction_type = "v_prediction"
upcast_attention = True upcast_attention = True
image_size = 768 image_size = 768 if model_version == BaseModelType.StableDiffusion2 else 512
else: else:
prediction_type = "epsilon" prediction_type = "epsilon"
upcast_attention = False upcast_attention = False

View File

@ -90,8 +90,7 @@ class ModelProbe(object):
to place it somewhere in the models directory hierarchy. If the model is to place it somewhere in the models directory hierarchy. If the model is
already loaded into memory, you may provide it as model in order to avoid already loaded into memory, you may provide it as model in order to avoid
opening it a second time. The prediction_type_helper callable is a function that receives opening it a second time. The prediction_type_helper callable is a function that receives
the path to the model and returns the BaseModelType. It is called to distinguish the path to the model and returns the SchedulerPredictionType.
between V2-Base and V2-768 SD models.
""" """
if model_path: if model_path:
format_type = "diffusers" if model_path.is_dir() else "checkpoint" format_type = "diffusers" if model_path.is_dir() else "checkpoint"
@ -305,25 +304,36 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
else: else:
raise InvalidModelException("Cannot determine base type") raise InvalidModelException("Cannot determine base type")
def get_scheduler_prediction_type(self) -> SchedulerPredictionType: def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
"""Return model prediction type."""
# if there is a .yaml associated with this checkpoint, then we do not need
# to probe for the prediction type as it will be ignored.
if self.checkpoint_path and self.checkpoint_path.with_suffix(".yaml").exists():
return None
type = self.get_base_type() type = self.get_base_type()
if type == BaseModelType.StableDiffusion1: if type == BaseModelType.StableDiffusion2:
return SchedulerPredictionType.Epsilon checkpoint = self.checkpoint
checkpoint = self.checkpoint state_dict = self.checkpoint.get("state_dict") or checkpoint
state_dict = self.checkpoint.get("state_dict") or checkpoint key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: if "global_step" in checkpoint:
if "global_step" in checkpoint: if checkpoint["global_step"] == 220000:
if checkpoint["global_step"] == 220000: return SchedulerPredictionType.Epsilon
return SchedulerPredictionType.Epsilon elif checkpoint["global_step"] == 110000:
elif checkpoint["global_step"] == 110000: return SchedulerPredictionType.VPrediction
return SchedulerPredictionType.VPrediction if self.helper and self.checkpoint_path:
if ( if helper_guess := self.helper(self.checkpoint_path):
self.checkpoint_path and self.helper and not self.checkpoint_path.with_suffix(".yaml").exists() return helper_guess
): # if a .yaml config file exists, then this step not needed return SchedulerPredictionType.VPrediction # a guess for sd2 ckpts
return self.helper(self.checkpoint_path)
else: elif type == BaseModelType.StableDiffusion1:
return None if self.helper and self.checkpoint_path:
if helper_guess := self.helper(self.checkpoint_path):
return helper_guess
return SchedulerPredictionType.Epsilon # a reasonable guess for sd1 ckpts
else:
return None
class VaeCheckpointProbe(CheckpointProbeBase): class VaeCheckpointProbe(CheckpointProbeBase):

View File

@ -0,0 +1,80 @@
model:
base_learning_rate: 1.0e-04
target: invokeai.backend.models.diffusion.ddpm.LatentDiffusion
params:
parameterization: "v"
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: invokeai.backend.stable_diffusion.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
personalization_config:
target: invokeai.backend.stable_diffusion.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ['sculpture']
per_image_tokens: false
num_vectors_per_token: 1
progressive_words: False
unet_config:
target: invokeai.backend.stable_diffusion.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: invokeai.backend.stable_diffusion.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: invokeai.backend.stable_diffusion.encoders.modules.WeightedFrozenCLIPEmbedder

View File

@ -574,7 +574,7 @@
"onnxModels": "Onnx", "onnxModels": "Onnx",
"pathToCustomConfig": "Path To Custom Config", "pathToCustomConfig": "Path To Custom Config",
"pickModelType": "Pick Model Type", "pickModelType": "Pick Model Type",
"predictionType": "Prediction Type (for Stable Diffusion 2.x Models only)", "predictionType": "Prediction Type (for Stable Diffusion 2.x Models and occasional Stable Diffusion 1.x Models)",
"quickAdd": "Quick Add", "quickAdd": "Quick Add",
"repo_id": "Repo ID", "repo_id": "Repo ID",
"repoIDValidationMsg": "Online repository of your model", "repoIDValidationMsg": "Online repository of your model",

View File

@ -655,7 +655,7 @@
"onnxModels": "Onnx", "onnxModels": "Onnx",
"pathToCustomConfig": "Path To Custom Config", "pathToCustomConfig": "Path To Custom Config",
"pickModelType": "Pick Model Type", "pickModelType": "Pick Model Type",
"predictionType": "Prediction Type (for Stable Diffusion 2.x Models only)", "predictionType": "Prediction Type (for Stable Diffusion 2.x Models and occasional Stable Diffusion 1.x Models)",
"quickAdd": "Quick Add", "quickAdd": "Quick Add",
"repo_id": "Repo ID", "repo_id": "Repo ID",
"repoIDValidationMsg": "Online repository of your model", "repoIDValidationMsg": "Online repository of your model",