restore ability to convert merged inpaint .safetensors files

This commit is contained in:
Lincoln Stein 2023-07-30 18:20:12 -04:00 committed by psychedelicious
parent b3b94b5a8d
commit eeef1e08f8
2 changed files with 7 additions and 3 deletions

View File

@ -292,8 +292,9 @@ class DiffusersModel(ModelBase):
) )
break break
except Exception as e: except Exception as e:
# print("====ERR LOAD====") if not str(e).startswith('Error no file'):
# print(f"{variant}: {e}") print("====ERR LOAD====")
print(f"{variant}: {e}")
pass pass
else: else:
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model") raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")

View File

@ -4,6 +4,7 @@ from enum import Enum
from pydantic import Field from pydantic import Field
from pathlib import Path from pathlib import Path
from typing import Literal, Optional, Union from typing import Literal, Optional, Union
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
from .base import ( from .base import (
ModelConfigBase, ModelConfigBase,
BaseModelType, BaseModelType,
@ -21,7 +22,6 @@ import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf from omegaconf import OmegaConf
class StableDiffusion1ModelFormat(str, Enum): class StableDiffusion1ModelFormat(str, Enum):
Checkpoint = "checkpoint" Checkpoint = "checkpoint"
Diffusers = "diffusers" Diffusers = "diffusers"
@ -263,6 +263,8 @@ def _convert_ckpt_and_cache(
weights = app_config.models_path / model_config.path weights = app_config.models_path / model_config.path
config_file = app_config.root_path / model_config.config config_file = app_config.root_path / model_config.config
output_path = Path(output_path) output_path = Path(output_path)
variant = model_config.variant
pipeline_class = StableDiffusionInpaintPipeline if variant=='inpaint' else StableDiffusionPipeline
# return cached version if it exists # return cached version if it exists
if output_path.exists(): if output_path.exists():
@ -289,6 +291,7 @@ def _convert_ckpt_and_cache(
original_config_file=config_file, original_config_file=config_file,
extract_ema=True, extract_ema=True,
scan_needed=True, scan_needed=True,
pipeline_class=pipeline_class,
from_safetensors=weights.suffix == ".safetensors", from_safetensors=weights.suffix == ".safetensors",
precision=torch_dtype(choose_torch_device()), precision=torch_dtype(choose_torch_device()),
**kwargs, **kwargs,