This commit is contained in:
Sergey Borisov 2023-06-12 16:14:09 +03:00
parent 9fa78443de
commit 36eb1bd893
5 changed files with 71 additions and 23 deletions

View File

@ -64,7 +64,7 @@ class ModelLoaderInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ModelLoaderOutput: def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = BaseModelType.StableDiffusion2 # TODO: base_model = BaseModelType.StableDiffusion1_5 # TODO:
# TODO: not found exceptions # TODO: not found exceptions
if not context.services.model_manager.model_exists( if not context.services.model_manager.model_exists(
@ -116,7 +116,7 @@ class ModelLoaderInvocation(BaseInvocation):
model_name=self.model_name, model_name=self.model_name,
base_model=base_model, base_model=base_model,
model_type=ModelType.Pipeline, model_type=ModelType.Pipeline,
submodel=SDModelType.Scheduler, submodel=SubModelType.Scheduler,
), ),
loras=[], loras=[],
), ),
@ -125,13 +125,13 @@ class ModelLoaderInvocation(BaseInvocation):
model_name=self.model_name, model_name=self.model_name,
base_model=base_model, base_model=base_model,
model_type=ModelType.Pipeline, model_type=ModelType.Pipeline,
submodel=SDModelType.Tokenizer, submodel=SubModelType.Tokenizer,
), ),
text_encoder=ModelInfo( text_encoder=ModelInfo(
model_name=self.model_name, model_name=self.model_name,
base_model=base_model, base_model=base_model,
model_type=ModelType.Pipeline, model_type=ModelType.Pipeline,
submodel=SDModelType.TextEncoder, submodel=SubModelType.TextEncoder,
), ),
loras=[], loras=[],
), ),
@ -140,7 +140,7 @@ class ModelLoaderInvocation(BaseInvocation):
model_name=self.model_name, model_name=self.model_name,
base_model=base_model, base_model=base_model,
model_type=ModelType.Pipeline, model_type=ModelType.Pipeline,
submodel=SDModelType.Vae, submodel=SubModelType.Vae,
), ),
) )
) )

View File

@ -162,7 +162,8 @@ class ModelCache(object):
if model_info_key not in self.model_infos: if model_info_key not in self.model_infos:
self.model_infos[model_info_key] = model_class( self.model_infos[model_info_key] = model_class(
model_path, model_path,
model_class, base_model,
model_type,
) )
return self.model_infos[model_info_key] return self.model_infos[model_info_key]
@ -208,7 +209,7 @@ class ModelCache(object):
# clean memory to make MemoryUsage() more accurate # clean memory to make MemoryUsage() more accurate
gc.collect() gc.collect()
model = model_info.get_model(submodel, torch_dtype=self.precision) model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
if mem_used := model_info.get_size(submodel): if mem_used := model_info.get_size(submodel):
self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB') self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')

View File

@ -430,10 +430,10 @@ class ModelManager(object):
# vae/movq override # vae/movq override
# TODO: # TODO:
if submodel is not None and submodel in model_config: if submodel_type is not None and submodel_type in model_config:
model_path = model_config[submodel] model_path = model_config[submodel_type]
model_type = submodel model_type = submodel_type
submodel = None submodel_type = None
dst_convert_path = None # TODO: dst_convert_path = None # TODO:
model_path = model_class.convert_if_required( model_path = model_class.convert_if_required(
@ -443,9 +443,11 @@ class ModelManager(object):
) )
model_context = self.cache.get_model( model_context = self.cache.get_model(
model_path, model_path=model_path,
model_class, model_class=model_class,
submodel_type, base_model=base_model,
model_type=model_type,
submodel=submodel_type,
) )
hash = "<NO_HASH>" # TODO: hash = "<NO_HASH>" # TODO:

View File

@ -1,3 +1,4 @@
import os
import sys import sys
import typing import typing
import inspect import inspect

View File

@ -102,7 +102,8 @@ class StableDiffusion15Model(DiffusersModel):
elif kwargs["format"] == "diffusers": elif kwargs["format"] == "diffusers":
unet_config_path = os.path.join(kwargs["path"], "unet", "config.json") unet_config_path = os.path.join(kwargs["path"], "unet", "config.json")
if os.path.exists(unet_config_path): if os.path.exists(unet_config_path):
unet_config = json.loads(unet_config_path) with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read())
in_channels = unet_config['in_channels'] in_channels = unet_config['in_channels']
else: else:
@ -135,10 +136,14 @@ class StableDiffusion15Model(DiffusersModel):
return "checkpoint" return "checkpoint"
@classmethod @classmethod
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: Optional[dict]) -> str: def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
cfg = cls.build_config(**config) if isinstance(config, cls.CheckpointConfig):
if isinstance(cfg, cls.CheckpointConfig): return _convert_ckpt_and_cache(
return _convert_ckpt_and_cache(cfg) # TODO: args version=BaseModelType.StableDiffusion1_5,
config=config.dict(),
in_path=model_path,
out_path=dst_cache_path,
) # TODO: args
else: else:
return model_path return model_path
@ -154,10 +159,22 @@ class StableDiffusion2BaseModel(StableDiffusion15Model):
model_type=ModelType.Pipeline, model_type=ModelType.Pipeline,
) )
class StableDiffusion2Model(DiffusersModel): @classmethod
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion2Base,
config=config.dict(),
in_path=model_path,
out_path=dst_cache_path,
) # TODO: args
else:
return model_path
class StableDiffusion2Model(StableDiffusion15Model):
# TODO: str -> Path? # TODO: str -> Path?
# overwrite configs # TODO: check that configs overwriten
class DiffusersConfig(ModelConfigBase): class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"] format: Literal["diffusers"]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
@ -174,22 +191,49 @@ class StableDiffusion2Model(DiffusersModel):
# skip StableDiffusion15Model __init__ # skip StableDiffusion15Model __init__
assert base_model == BaseModelType.StableDiffusion2 assert base_model == BaseModelType.StableDiffusion2
assert model_type == ModelType.Pipeline assert model_type == ModelType.Pipeline
super().__init__( # skip StableDiffusion15Model __init__
super(StableDiffusion15Model, self).__init__(
model_path=model_path, model_path=model_path,
base_model=BaseModelType.StableDiffusion2, base_model=BaseModelType.StableDiffusion2,
model_type=ModelType.Pipeline, model_type=ModelType.Pipeline,
) )
@classmethod
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion2,
config=config.dict(),
in_path=model_path,
out_path=dst_cache_path,
) # TODO: args
else:
return model_path
# TODO: rework # TODO: rework
DictConfig = dict DictConfig = dict
def _convert_ckpt_and_cache(self, mconfig: DictConfig) -> str: def _convert_ckpt_and_cache(
self,
version: BaseModelType,
mconfig: dict, # TODO:
in_path: str,
out_path: str,
) -> str:
""" """
Convert the checkpoint model indicated in mconfig into a Convert the checkpoint model indicated in mconfig into a
diffusers, cache it to disk, and return Path to converted diffusers, cache it to disk, and return Path to converted
file. If already on disk then just returns Path. file. If already on disk then just returns Path.
""" """
raise NotImplementedError()
app_config = InvokeAIAppConfig.get_config() app_config = InvokeAIAppConfig.get_config()
#if "config" not in mconfig:
# if version == BaseModelType.StableDiffusion1_5:
#if
#mconfig["config"] = app_config.config_dir / "stable-diffusion" / "v1-inference.yaml"
weights = app_config.root_dir / mconfig.path weights = app_config.root_dir / mconfig.path
config_file = app_config.root_dir / mconfig.config config_file = app_config.root_dir / mconfig.config
diffusers_path = app_config.converted_ckpts_dir / weights.stem diffusers_path = app_config.converted_ckpts_dir / weights.stem