This commit is contained in:
Sergey Borisov
2023-06-12 16:14:09 +03:00
parent 9fa78443de
commit 36eb1bd893
5 changed files with 71 additions and 23 deletions

View File

@ -1,3 +1,4 @@
import os
import sys
import typing
import inspect

View File

@ -102,7 +102,8 @@ class StableDiffusion15Model(DiffusersModel):
elif kwargs["format"] == "diffusers":
unet_config_path = os.path.join(kwargs["path"], "unet", "config.json")
if os.path.exists(unet_config_path):
unet_config = json.loads(unet_config_path)
with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read())
in_channels = unet_config['in_channels']
else:
@ -135,10 +136,14 @@ class StableDiffusion15Model(DiffusersModel):
return "checkpoint"
@classmethod
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: Optional[dict]) -> str:
cfg = cls.build_config(**config)
if isinstance(cfg, cls.CheckpointConfig):
return _convert_ckpt_and_cache(cfg) # TODO: args
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion1_5,
config=config.dict(),
in_path=model_path,
out_path=dst_cache_path,
) # TODO: args
else:
return model_path
@ -154,10 +159,22 @@ class StableDiffusion2BaseModel(StableDiffusion15Model):
model_type=ModelType.Pipeline,
)
class StableDiffusion2Model(DiffusersModel):
@classmethod
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion2Base,
config=config.dict(),
in_path=model_path,
out_path=dst_cache_path,
) # TODO: args
else:
return model_path
class StableDiffusion2Model(StableDiffusion15Model):
# TODO: str -> Path?
# overwrite configs
# TODO: check that configs overwriten
class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"]
vae: Optional[str] = Field(None)
@ -174,22 +191,49 @@ class StableDiffusion2Model(DiffusersModel):
# skip StableDiffusion15Model __init__
assert base_model == BaseModelType.StableDiffusion2
assert model_type == ModelType.Pipeline
super().__init__(
# skip StableDiffusion15Model __init__
super(StableDiffusion15Model, self).__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusion2,
model_type=ModelType.Pipeline,
)
@classmethod
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion2,
config=config.dict(),
in_path=model_path,
out_path=dst_cache_path,
) # TODO: args
else:
return model_path
# TODO: rework
DictConfig = dict
def _convert_ckpt_and_cache(self, mconfig: DictConfig) -> str:
def _convert_ckpt_and_cache(
self,
version: BaseModelType,
mconfig: dict, # TODO:
in_path: str,
out_path: str,
) -> str:
"""
Convert the checkpoint model indicated in mconfig into a
diffusers, cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
raise NotImplementedError()
app_config = InvokeAIAppConfig.get_config()
#if "config" not in mconfig:
# if version == BaseModelType.StableDiffusion1_5:
#if
#mconfig["config"] = app_config.config_dir / "stable-diffusion" / "v1-inference.yaml"
weights = app_config.root_dir / mconfig.path
config_file = app_config.root_dir / mconfig.config
diffusers_path = app_config.converted_ckpts_dir / weights.stem