Sergey Borisov 2023-06-12 16:14:09 +03:00
parent 9fa78443de
commit 36eb1bd893
5 changed files with 71 additions and 23 deletions


@@ -64,7 +64,7 @@ class ModelLoaderInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = BaseModelType.StableDiffusion2 # TODO:
base_model = BaseModelType.StableDiffusion1_5 # TODO:
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
@@ -116,7 +116,7 @@ class ModelLoaderInvocation(BaseInvocation):
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SDModelType.Scheduler,
submodel=SubModelType.Scheduler,
),
loras=[],
),
@@ -125,13 +125,13 @@ class ModelLoaderInvocation(BaseInvocation):
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SDModelType.Tokenizer,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SDModelType.TextEncoder,
submodel=SubModelType.TextEncoder,
),
loras=[],
),
@@ -140,7 +140,7 @@ class ModelLoaderInvocation(BaseInvocation):
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SDModelType.Vae,
submodel=SubModelType.Vae,
),
)
)
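For orientation, a minimal self-contained sketch of the SubModelType / ModelInfo shape these hunks fill in (the enum and dataclass below are illustrative stand-ins, not InvokeAI's actual definitions):

from dataclasses import dataclass
from enum import Enum
from typing import Optional

class SubModelType(str, Enum):
    # illustrative subset of the submodel kinds referenced above
    Scheduler = "scheduler"
    Tokenizer = "tokenizer"
    TextEncoder = "text_encoder"
    Vae = "vae"

@dataclass
class ModelInfo:
    # illustrative stand-in; the real class also carries loader context
    model_name: str
    base_model: str
    model_type: str
    submodel: Optional[SubModelType] = None

vae_info = ModelInfo(
    model_name="stable-diffusion-1-5",
    base_model="sd-1",
    model_type="pipeline",
    submodel=SubModelType.Vae,
)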


@@ -162,7 +162,8 @@ class ModelCache(object):
if model_info_key not in self.model_infos:
self.model_infos[model_info_key] = model_class(
model_path,
model_class,
base_model,
model_type,
)
return self.model_infos[model_info_key]
@@ -208,7 +209,7 @@ class ModelCache(object):
# clean memory to make MemoryUsage() more accurate
gc.collect()
model = model_info.get_model(submodel, torch_dtype=self.precision)
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
if mem_used := model_info.get_size(submodel):
self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')
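The two hunks above change how cached model entries are built and fetched: model info objects are now constructed with the base model and model type, and submodels are requested through the child_type keyword. A minimal sketch of that caching pattern, assuming a hypothetical TinyModelInfoCache (not the real ModelCache API):

class TinyModelInfoCache:
    def __init__(self):
        self._infos = {}

    def get_model_info(self, model_class, model_path, base_model, model_type):
        # construct once per (path, base_model, model_type) and reuse afterwards,
        # mirroring the new model_class(model_path, base_model, model_type) call
        key = (model_path, base_model, model_type)
        if key not in self._infos:
            self._infos[key] = model_class(model_path, base_model, model_type)
        return self._infos[key]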


@@ -430,10 +430,10 @@ class ModelManager(object):
# vae/movq override
# TODO:
if submodel is not None and submodel in model_config:
model_path = model_config[submodel]
model_type = submodel
submodel = None
if submodel_type is not None and submodel_type in model_config:
model_path = model_config[submodel_type]
model_type = submodel_type
submodel_type = None
dst_convert_path = None # TODO:
model_path = model_class.convert_if_required(
@@ -443,9 +443,11 @@ class ModelManager(object):
)
model_context = self.cache.get_model(
model_path,
model_class,
submodel_type,
model_path=model_path,
model_class=model_class,
base_model=base_model,
model_type=model_type,
submodel=submodel_type,
)
hash = "<NO_HASH>" # TODO:
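The vae/movq override above redirects a submodel request to a standalone model when the pipeline's config carries its own entry for it. A sketch of that resolution step on a plain dict (paths and keys are illustrative):

def resolve_override(model_path, model_type, submodel_type, model_config):
    # if the requested submodel has its own entry in the config, load that path
    # directly as a standalone model instead of pulling it out of the pipeline
    if submodel_type is not None and submodel_type in model_config:
        model_path = model_config[submodel_type]
        model_type = submodel_type
        submodel_type = None
    return model_path, model_type, submodel_type

cfg = {"path": "models/sd-1-5", "vae": "models/custom-vae"}
print(resolve_override("models/sd-1-5", "pipeline", "vae", cfg))
# -> ('models/custom-vae', 'vae', None)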


@@ -1,3 +1,4 @@
import os
import sys
import typing
import inspect


@@ -102,7 +102,8 @@ class StableDiffusion15Model(DiffusersModel):
elif kwargs["format"] == "diffusers":
unet_config_path = os.path.join(kwargs["path"], "unet", "config.json")
if os.path.exists(unet_config_path):
unet_config = json.loads(unet_config_path)
with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read())
in_channels = unet_config['in_channels']
else:
@@ -135,10 +136,14 @@ class StableDiffusion15Model(DiffusersModel):
return "checkpoint"
@classmethod
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: Optional[dict]) -> str:
cfg = cls.build_config(**config)
if isinstance(cfg, cls.CheckpointConfig):
return _convert_ckpt_and_cache(cfg) # TODO: args
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion1_5,
config=config.dict(),
in_path=model_path,
out_path=dst_cache_path,
) # TODO: args
else:
return model_path
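The reworked convert_if_required dispatches on the config type: checkpoint-format models are converted (and cached) to a diffusers folder, while models already in diffusers format pass through unchanged. A sketch of that dispatch, with a stand-in config class and the conversion left as a hypothetical callable:

from dataclasses import dataclass

@dataclass
class CheckpointConfig:  # stand-in for the checkpoint-format config class above
    path: str
    config: str  # path to the matching inference .yaml

def convert_if_required(model_path, dst_cache_path, config, convert_fn):
    if isinstance(config, CheckpointConfig):
        # convert_fn stands in for _convert_ckpt_and_cache(version=..., config=..., ...)
        return convert_fn(in_path=model_path, out_path=dst_cache_path)
    return model_path  # already diffusers: nothing to do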
@@ -154,10 +159,22 @@ class StableDiffusion2BaseModel(StableDiffusion15Model):
model_type=ModelType.Pipeline,
)
class StableDiffusion2Model(DiffusersModel):
@classmethod
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion2Base,
config=config.dict(),
in_path=model_path,
out_path=dst_cache_path,
) # TODO: args
else:
return model_path
class StableDiffusion2Model(StableDiffusion15Model):
# TODO: str -> Path?
# overwrite configs
# TODO: check that configs overwriten
class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"]
vae: Optional[str] = Field(None)
@@ -174,22 +191,49 @@ class StableDiffusion2Model(DiffusersModel):
# skip StableDiffusion15Model __init__
assert base_model == BaseModelType.StableDiffusion2
assert model_type == ModelType.Pipeline
super().__init__(
# skip StableDiffusion15Model __init__
super(StableDiffusion15Model, self).__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusion2,
model_type=ModelType.Pipeline,
)
@classmethod
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion2,
config=config.dict(),
in_path=model_path,
out_path=dst_cache_path,
) # TODO: args
else:
return model_path
# TODO: rework
DictConfig = dict
def _convert_ckpt_and_cache(self, mconfig: DictConfig) -> str:
def _convert_ckpt_and_cache(
self,
version: BaseModelType,
mconfig: dict, # TODO:
in_path: str,
out_path: str,
) -> str:
"""
Convert the checkpoint model indicated in mconfig into a
diffusers, cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
raise NotImplementedError()
app_config = InvokeAIAppConfig.get_config()
#if "config" not in mconfig:
# if version == BaseModelType.StableDiffusion1_5:
#if
#mconfig["config"] = app_config.config_dir / "stable-diffusion" / "v1-inference.yaml"
weights = app_config.root_dir / mconfig.path
config_file = app_config.root_dir / mconfig.config
diffusers_path = app_config.converted_ckpts_dir / weights.stem
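The docstring above promises idempotent conversion: reuse the converted diffusers folder if it already exists, otherwise convert once and cache it. A minimal sketch of that convert-or-reuse contract, with the actual ckpt-to-diffusers conversion left as a hypothetical callable:

from pathlib import Path

def convert_or_reuse(weights: Path, out_dir: Path, do_convert) -> Path:
    out_path = out_dir / weights.stem
    if out_path.exists():
        return out_path                  # already converted on a previous run
    out_path.parent.mkdir(parents=True, exist_ok=True)
    do_convert(weights, out_path)        # stand-in for the real ckpt -> diffusers conversion
    return out_path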