mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fixes
This commit is contained in:
parent
9fa78443de
commit
36eb1bd893
@ -64,7 +64,7 @@ class ModelLoaderInvocation(BaseInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
|
||||
base_model = BaseModelType.StableDiffusion2 # TODO:
|
||||
base_model = BaseModelType.StableDiffusion1_5 # TODO:
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
@ -116,7 +116,7 @@ class ModelLoaderInvocation(BaseInvocation):
|
||||
model_name=self.model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
submodel=SDModelType.Scheduler,
|
||||
submodel=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
@ -125,13 +125,13 @@ class ModelLoaderInvocation(BaseInvocation):
|
||||
model_name=self.model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
submodel=SDModelType.Tokenizer,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
submodel=SDModelType.TextEncoder,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
@ -140,7 +140,7 @@ class ModelLoaderInvocation(BaseInvocation):
|
||||
model_name=self.model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
submodel=SDModelType.Vae,
|
||||
submodel=SubModelType.Vae,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
@ -162,7 +162,8 @@ class ModelCache(object):
|
||||
if model_info_key not in self.model_infos:
|
||||
self.model_infos[model_info_key] = model_class(
|
||||
model_path,
|
||||
model_class,
|
||||
base_model,
|
||||
model_type,
|
||||
)
|
||||
|
||||
return self.model_infos[model_info_key]
|
||||
@ -208,7 +209,7 @@ class ModelCache(object):
|
||||
|
||||
# clean memory to make MemoryUsage() more accurate
|
||||
gc.collect()
|
||||
model = model_info.get_model(submodel, torch_dtype=self.precision)
|
||||
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
|
||||
if mem_used := model_info.get_size(submodel):
|
||||
self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')
|
||||
|
||||
|
@ -430,10 +430,10 @@ class ModelManager(object):
|
||||
|
||||
# vae/movq override
|
||||
# TODO:
|
||||
if submodel is not None and submodel in model_config:
|
||||
model_path = model_config[submodel]
|
||||
model_type = submodel
|
||||
submodel = None
|
||||
if submodel_type is not None and submodel_type in model_config:
|
||||
model_path = model_config[submodel_type]
|
||||
model_type = submodel_type
|
||||
submodel_type = None
|
||||
|
||||
dst_convert_path = None # TODO:
|
||||
model_path = model_class.convert_if_required(
|
||||
@ -443,9 +443,11 @@ class ModelManager(object):
|
||||
)
|
||||
|
||||
model_context = self.cache.get_model(
|
||||
model_path,
|
||||
model_class,
|
||||
submodel_type,
|
||||
model_path=model_path,
|
||||
model_class=model_class,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel_type,
|
||||
)
|
||||
|
||||
hash = "<NO_HASH>" # TODO:
|
||||
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
import sys
|
||||
import typing
|
||||
import inspect
|
||||
|
@ -102,7 +102,8 @@ class StableDiffusion15Model(DiffusersModel):
|
||||
elif kwargs["format"] == "diffusers":
|
||||
unet_config_path = os.path.join(kwargs["path"], "unet", "config.json")
|
||||
if os.path.exists(unet_config_path):
|
||||
unet_config = json.loads(unet_config_path)
|
||||
with open(unet_config_path, "r") as f:
|
||||
unet_config = json.loads(f.read())
|
||||
in_channels = unet_config['in_channels']
|
||||
|
||||
else:
|
||||
@ -135,10 +136,14 @@ class StableDiffusion15Model(DiffusersModel):
|
||||
return "checkpoint"
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: Optional[dict]) -> str:
|
||||
cfg = cls.build_config(**config)
|
||||
if isinstance(cfg, cls.CheckpointConfig):
|
||||
return _convert_ckpt_and_cache(cfg) # TODO: args
|
||||
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
|
||||
if isinstance(config, cls.CheckpointConfig):
|
||||
return _convert_ckpt_and_cache(
|
||||
version=BaseModelType.StableDiffusion1_5,
|
||||
config=config.dict(),
|
||||
in_path=model_path,
|
||||
out_path=dst_cache_path,
|
||||
) # TODO: args
|
||||
else:
|
||||
return model_path
|
||||
|
||||
@ -154,10 +159,22 @@ class StableDiffusion2BaseModel(StableDiffusion15Model):
|
||||
model_type=ModelType.Pipeline,
|
||||
)
|
||||
|
||||
class StableDiffusion2Model(DiffusersModel):
|
||||
@classmethod
|
||||
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
|
||||
if isinstance(config, cls.CheckpointConfig):
|
||||
return _convert_ckpt_and_cache(
|
||||
version=BaseModelType.StableDiffusion2Base,
|
||||
config=config.dict(),
|
||||
in_path=model_path,
|
||||
out_path=dst_cache_path,
|
||||
) # TODO: args
|
||||
else:
|
||||
return model_path
|
||||
|
||||
class StableDiffusion2Model(StableDiffusion15Model):
|
||||
|
||||
# TODO: str -> Path?
|
||||
# overwrite configs
|
||||
# TODO: check that configs overwriten
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
format: Literal["diffusers"]
|
||||
vae: Optional[str] = Field(None)
|
||||
@ -174,22 +191,49 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
# skip StableDiffusion15Model __init__
|
||||
assert base_model == BaseModelType.StableDiffusion2
|
||||
assert model_type == ModelType.Pipeline
|
||||
super().__init__(
|
||||
# skip StableDiffusion15Model __init__
|
||||
super(StableDiffusion15Model, self).__init__(
|
||||
model_path=model_path,
|
||||
base_model=BaseModelType.StableDiffusion2,
|
||||
model_type=ModelType.Pipeline,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
|
||||
if isinstance(config, cls.CheckpointConfig):
|
||||
return _convert_ckpt_and_cache(
|
||||
version=BaseModelType.StableDiffusion2,
|
||||
config=config.dict(),
|
||||
in_path=model_path,
|
||||
out_path=dst_cache_path,
|
||||
) # TODO: args
|
||||
else:
|
||||
return model_path
|
||||
|
||||
|
||||
# TODO: rework
|
||||
DictConfig = dict
|
||||
def _convert_ckpt_and_cache(self, mconfig: DictConfig) -> str:
|
||||
def _convert_ckpt_and_cache(
|
||||
self,
|
||||
version: BaseModelType,
|
||||
mconfig: dict, # TODO:
|
||||
in_path: str,
|
||||
out_path: str,
|
||||
) -> str:
|
||||
"""
|
||||
Convert the checkpoint model indicated in mconfig into a
|
||||
diffusers, cache it to disk, and return Path to converted
|
||||
file. If already on disk then just returns Path.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
|
||||
#if "config" not in mconfig:
|
||||
# if version == BaseModelType.StableDiffusion1_5:
|
||||
#if
|
||||
#mconfig["config"] = app_config.config_dir / "stable-diffusion" / "v1-inference.yaml"
|
||||
|
||||
|
||||
weights = app_config.root_dir / mconfig.path
|
||||
config_file = app_config.root_dir / mconfig.config
|
||||
diffusers_path = app_config.converted_ckpts_dir / weights.stem
|
||||
|
Loading…
Reference in New Issue
Block a user