mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fixes
This commit is contained in:
@ -1,3 +1,4 @@
|
||||
import os
|
||||
import sys
|
||||
import typing
|
||||
import inspect
|
||||
|
@ -102,7 +102,8 @@ class StableDiffusion15Model(DiffusersModel):
|
||||
elif kwargs["format"] == "diffusers":
|
||||
unet_config_path = os.path.join(kwargs["path"], "unet", "config.json")
|
||||
if os.path.exists(unet_config_path):
|
||||
unet_config = json.loads(unet_config_path)
|
||||
with open(unet_config_path, "r") as f:
|
||||
unet_config = json.loads(f.read())
|
||||
in_channels = unet_config['in_channels']
|
||||
|
||||
else:
|
||||
@ -135,10 +136,14 @@ class StableDiffusion15Model(DiffusersModel):
|
||||
return "checkpoint"
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: Optional[dict]) -> str:
|
||||
cfg = cls.build_config(**config)
|
||||
if isinstance(cfg, cls.CheckpointConfig):
|
||||
return _convert_ckpt_and_cache(cfg) # TODO: args
|
||||
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
|
||||
if isinstance(config, cls.CheckpointConfig):
|
||||
return _convert_ckpt_and_cache(
|
||||
version=BaseModelType.StableDiffusion1_5,
|
||||
config=config.dict(),
|
||||
in_path=model_path,
|
||||
out_path=dst_cache_path,
|
||||
) # TODO: args
|
||||
else:
|
||||
return model_path
|
||||
|
||||
@ -154,10 +159,22 @@ class StableDiffusion2BaseModel(StableDiffusion15Model):
|
||||
model_type=ModelType.Pipeline,
|
||||
)
|
||||
|
||||
class StableDiffusion2Model(DiffusersModel):
|
||||
@classmethod
|
||||
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
|
||||
if isinstance(config, cls.CheckpointConfig):
|
||||
return _convert_ckpt_and_cache(
|
||||
version=BaseModelType.StableDiffusion2Base,
|
||||
config=config.dict(),
|
||||
in_path=model_path,
|
||||
out_path=dst_cache_path,
|
||||
) # TODO: args
|
||||
else:
|
||||
return model_path
|
||||
|
||||
class StableDiffusion2Model(StableDiffusion15Model):
|
||||
|
||||
# TODO: str -> Path?
|
||||
# overwrite configs
|
||||
# TODO: check that configs overwriten
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
format: Literal["diffusers"]
|
||||
vae: Optional[str] = Field(None)
|
||||
@ -174,22 +191,49 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
# skip StableDiffusion15Model __init__
|
||||
assert base_model == BaseModelType.StableDiffusion2
|
||||
assert model_type == ModelType.Pipeline
|
||||
super().__init__(
|
||||
# skip StableDiffusion15Model __init__
|
||||
super(StableDiffusion15Model, self).__init__(
|
||||
model_path=model_path,
|
||||
base_model=BaseModelType.StableDiffusion2,
|
||||
model_type=ModelType.Pipeline,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
|
||||
if isinstance(config, cls.CheckpointConfig):
|
||||
return _convert_ckpt_and_cache(
|
||||
version=BaseModelType.StableDiffusion2,
|
||||
config=config.dict(),
|
||||
in_path=model_path,
|
||||
out_path=dst_cache_path,
|
||||
) # TODO: args
|
||||
else:
|
||||
return model_path
|
||||
|
||||
|
||||
# TODO: rework
|
||||
DictConfig = dict
|
||||
def _convert_ckpt_and_cache(self, mconfig: DictConfig) -> str:
|
||||
def _convert_ckpt_and_cache(
|
||||
self,
|
||||
version: BaseModelType,
|
||||
mconfig: dict, # TODO:
|
||||
in_path: str,
|
||||
out_path: str,
|
||||
) -> str:
|
||||
"""
|
||||
Convert the checkpoint model indicated in mconfig into a
|
||||
diffusers, cache it to disk, and return Path to converted
|
||||
file. If already on disk then just returns Path.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
|
||||
#if "config" not in mconfig:
|
||||
# if version == BaseModelType.StableDiffusion1_5:
|
||||
#if
|
||||
#mconfig["config"] = app_config.config_dir / "stable-diffusion" / "v1-inference.yaml"
|
||||
|
||||
|
||||
weights = app_config.root_dir / mconfig.path
|
||||
config_file = app_config.root_dir / mconfig.config
|
||||
diffusers_path = app_config.converted_ckpts_dir / weights.stem
|
||||
|
Reference in New Issue
Block a user