convert add_model(), del_model(), list_models() etc to use bifurcated names

This commit is contained in:
Lincoln Stein 2023-05-13 14:44:44 -04:00
parent bc96727cbe
commit 72967bf118
6 changed files with 217 additions and 128 deletions

View File

@ -1,12 +1,9 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
import shutil
import asyncio
from typing import Annotated, Any, List, Literal, Optional, Union
from typing import Annotated, Literal, Optional, Union
from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as
from pathlib import Path
from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -19,6 +16,15 @@ class VaeRepo(BaseModel):
class ModelInfo(BaseModel):
description: Optional[str] = Field(description="A description of the model")
model_name: str = Field(description="The name of the model")
model_type: str = Field(description="The type of the model")
class DiffusersModelInfo(ModelInfo):
format: Literal['folder'] = 'folder'
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
path: Optional[str] = Field(description="The path to the model")
class CkptModelInfo(ModelInfo):
format: Literal['ckpt'] = 'ckpt'
@ -29,12 +35,8 @@ class CkptModelInfo(ModelInfo):
width: Optional[int] = Field(description="The width of the model")
height: Optional[int] = Field(description="The height of the model")
class DiffusersModelInfo(ModelInfo):
format: Literal['diffusers'] = 'diffusers'
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
path: Optional[str] = Field(description="The path to the model")
class SafetensorsModelInfo(CkptModelInfo):
format: Literal['safetensors'] = 'safetensors'
class CreateModelRequest(BaseModel):
name: str = Field(description="The name of the model")
@ -56,7 +58,7 @@ class ConvertedModelResponse(BaseModel):
info: DiffusersModelInfo = Field(description="The converted model info")
class ModelsList(BaseModel):
models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]]
models: dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]
@models_router.get(
@ -121,7 +123,7 @@ async def delete_model(model_name: str) -> None:
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
else:
logger.error(f"Model not found")
logger.error("Model not found")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")

View File

@ -60,14 +60,14 @@ class ModelLoaderInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
# TODO: not found exceptions
if not context.services.model_manager.valid_model(
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.diffusers,
):
raise Exception(f"Unkown model name: {self.model_name}!")
"""
if not context.services.model_manager.valid_model(
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.diffusers,
submodel=SDModelType.tokenizer,
@ -76,7 +76,7 @@ class ModelLoaderInvocation(BaseInvocation):
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.valid_model(
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.diffusers,
submodel=SDModelType.text_encoder,
@ -85,7 +85,7 @@ class ModelLoaderInvocation(BaseInvocation):
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.valid_model(
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.diffusers,
submodel=SDModelType.unet,

View File

@ -59,35 +59,35 @@ class ModelManagerServiceBase(ABC):
pass
@abstractmethod
def valid_model(self, model_name: str) -> bool:
"""
Given a model name, returns True if it is a valid
identifier.
"""
def model_exists(
self,
model_name: str,
model_type: SDModelType
) -> bool:
pass
@abstractmethod
def default_model(self) -> Union[str,None]:
def default_model(self) -> Union[Tuple(str, SDModelType),None],
"""
Returns the name of the default model, or None
Returns the name and typeof the default model, or None
if none is defined.
"""
pass
@abstractmethod
def set_default_model(self, model_name:str):
def set_default_model(self, model_name: str, model_type: SDModelType):
"""Sets the default model to the indicated name."""
pass
@abstractmethod
def model_info(self, model_name: str)->dict:
def model_info(self, model_name: str, model_type: SDModelType)->dict:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
"""
pass
@abstractmethod
def model_names(self)->list[str]:
def model_names(self)->List[Tuple(str, SDModelType)]:
"""
Returns a list of all the model names known.
"""
@ -97,18 +97,25 @@ class ModelManagerServiceBase(ABC):
def list_models(self)->dict:
"""
Return a dict of models in the format:
{ model_name1: {'status': ('active'|'cached'|'not loaded'),
{ model_key1: {'status': 'active'|'cached'|'not loaded',
'model_name' : name,
'model_type' : SDModelType,
'description': description,
'format': ('ckpt'|'diffusers'|'vae'|'text_encoder'|'tokenizer'|'lora'...),
'format': 'folder'|'safetensors'|'ckpt'
},
model_name2: { etc }
model_key2: { etc }
"""
pass
@abstractmethod
def add_model(
self, model_name: str, model_attributes: dict, clobber: bool = False)->None:
self,
model_name: str,
model_type: SDModelType,
model_attributes: dict,
clobber: bool = False
)->None:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
@ -121,7 +128,7 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def del_model(self,
model_name: str,
model_type: SDModelType=SDModelType.diffusers,
model_type: SDModelType,
delete_files: bool = False):
"""
Delete the named model from configuration. If delete_files is true,
@ -332,33 +339,37 @@ class ModelManagerService(ModelManagerServiceBase):
return model_info
def valid_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> bool:
def model_exists(
self,
model_name: str,
model_type: SDModelType
) -> bool:
"""
Given a model name, returns True if it is a valid
identifier.
"""
return self.mgr.valid_model(
return self.mgr.model_exists(
model_name,
model_type)
def default_model(self) -> Union[str,None]:
def default_model(self) -> Union[Tuple(str, SDModelType),None],
"""
Returns the name of the default model, or None
if none is defined.
"""
return self.mgr.default_model()
def set_default_model(self, model_name:str):
def set_default_model(self, model_name:str, model_type: SDModelType):
"""Sets the default model to the indicated name."""
self.mgr.set_default_model(model_name)
def model_info(self, model_name: str)->dict:
def model_info(self, model_name: str, model_type: SDModelType)->dict:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
"""
return self.mgr.model_info(model_name)
def model_names(self)->list[str]:
def model_names(self)->List[Tuple(str, SDModelType)]:
"""
Returns a list of all the model names known.
"""
@ -367,16 +378,21 @@ class ModelManagerService(ModelManagerServiceBase):
def list_models(self)->dict:
"""
Return a dict of models in the format:
{ model_name1: {'status': ('active'|'cached'|'not loaded'),
{ model_key: {'status': 'active'|'cached'|'not loaded',
'model_name' : name,
'model_type' : SDModelType,
'description': description,
'format': ('ckpt'|'diffusers'|'vae'|'text_encoder'|'tokenizer'|'lora'...),
'format': 'folder'|'safetensors'|'ckpt'
},
model_name2: { etc }
"""
return self.mgr.list_models()
def add_model(
self, model_name: str, model_attributes: dict, clobber: bool = False)->None:
self,
model_name: str,
model_type: SDModelType,
model_attributes: dict,
clobber: bool = False)->None:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
@ -384,7 +400,7 @@ class ModelManagerService(ModelManagerServiceBase):
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
return self.mgr.add_model(model_name, model_attributes, dict, clobber)
return self.mgr.add_model(model_name, model_type, model_attributes, dict, clobber)
def del_model(self,
@ -439,7 +455,7 @@ class ModelManagerService(ModelManagerServiceBase):
Creates an entry for the indicated textual inversion embedding file.
Call commit() to write out the configuration to models.yaml
"""
self.mgr(path, model_name, description)
self.mgr.import_embedding(path, model_name, description)
def heuristic_import(
self,

View File

@ -227,7 +227,7 @@ class Generate:
# don't accept invalid models
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
model = model or fallback
if not self.model_manager.valid_model(model):
if not self.model_manager.model_exists(model):
logger.warning(
f'"{model}" is not a known model name; falling back to {fallback}.'
)
@ -877,7 +877,7 @@ class Generate:
# the model cache does the loading and offloading
cache = self.model_manager
if not cache.valid_model(model_name):
if not cache.model_exists(model_name):
raise KeyError(
f'** "{model_name}" is not a known model name. Cannot change.'
)

View File

@ -9,6 +9,7 @@ return a SDModelInfo object that contains the following attributes:
model into GPU VRAM and returns the model for use.
See below for usage.
* name -- symbolic name of the model
* type -- SDModelType of the model
* hash -- unique hash for the model
* location -- path or repo_id of the model
* revision -- revision of the model if coming from a repo id,
@ -26,7 +27,7 @@ Typical usage:
max_cache_size=8
) # gigabytes
model_info = manager.get_model('stable-diffusion-1.5')
model_info = manager.get_model('stable-diffusion-1.5', SDModelType.diffusers)
with model_info.context as my_model:
my_model.latents_from_embeddings(...)
@ -54,15 +55,20 @@ MODELS.YAML
The general format of a models.yaml section is:
name-of-model:
format: diffusers|ckpt|vae|text_encoder|tokenizer...
type-of-model/name-of-model:
format: folder|ckpt|safetensors
repo_id: owner/repo
path: /path/to/local/file/or/directory
subfolder: subfolder-name
The format is one of {diffusers, ckpt, vae, text_encoder, tokenizer,
unet, scheduler, safety_checker, feature_extractor}, and correspond to
items in the SDModelType enum defined in model_cache.py
The type of model is given in the stanza key, and is one of
{diffusers, ckpt, vae, text_encoder, tokenizer, unet, scheduler,
safety_checker, feature_extractor, lora, textual_inversion}, and
correspond to items in the SDModelType enum defined in model_cache.py
The format indicates whether the model is organized as a folder with
model subdirectories, or is contained in a single checkpoint or
safetensors file.
One, but not both, of repo_id and path are provided. repo_id is the
HuggingFace repository ID of the model, and path points to the file or
@ -74,13 +80,13 @@ the main model. These are usually named after the model type, such as
This example summarizes the two ways of getting a non-diffuser model:
clip-test-1:
format: text_encoder
text_encoder/clip-test-1:
format: folder
repo_id: openai/clip-vit-large-patch14
description: Returns standalone CLIPTextModel
clip-test-2:
format: text_encoder
text_encoder/clip-test-2:
format: folder
repo_id: stabilityai/stable-diffusion-2
subfolder: text_encoder
description: Returns the text_encoder in the subfolder of the diffusers model (just the encoder in RAM)
@ -101,12 +107,12 @@ You may wish to use the same name for a related family of models. To
do this, disambiguate the stanza key with the model and and format
separated by "/". Example:
clip-large/tokenizer:
tokenizer/clip-large:
format: tokenizer
repo_id: openai/clip-vit-large-patch14
description: Returns standalone tokenizer
clip-large/text_encoder:
text_encoder/clip-large:
format: text_encoder
repo_id: openai/clip-vit-large-patch14
description: Returns standalone text encoder
@ -128,31 +134,41 @@ from __future__ import annotations
import os
import re
import textwrap
from contextlib import suppress
from dataclasses import dataclass
from enum import Enum, auto
from packaging import version
from pathlib import Path
from shutil import rmtree
from typing import Union, Callable, types
from contextlib import suppress
from typing import Callable, Optional, List, Tuple, Union, types
import safetensors
import safetensors.torch
import torch
import invokeai.backend.util.logging as logger
from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from invokeai.backend.globals import Globals, global_cache_dir, global_resolve_path
from .model_cache import ModelCache, ModelLocker, SDModelType, ModelStatus, SilenceWarnings
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import (Globals, global_cache_dir,
global_resolve_path)
from invokeai.backend.util import download_with_resume
from ..util import CUDA_DEVICE
from .model_cache import (ModelCache, ModelLocker, ModelStatus, SDModelType,
SilenceWarnings)
# We are only starting to number the config file with release 3.
# The config file version doesn't have to start at release version, but it will help
# reduce confusion.
CONFIG_FILE_VERSION='3.0.0'
# wanted to use pydantic here, but Generator objects not supported
@dataclass
class SDModelInfo():
context: ModelLocker
name: str
type: SDModelType
hash: str
location: Union[Path,str]
precision: torch.dtype
@ -208,14 +224,17 @@ class ModelManager(object):
type and precision are set up for a CUDA system running at half precision.
"""
if isinstance(config, DictConfig):
self.config = config
self.config_path = None
self.config = config
elif isinstance(config,(str,Path)):
self.config_path = config
self.config = OmegaConf.load(self.config_path)
else:
raise ValueError('config argument must be an OmegaConf object, a Path or a string')
# check config version number and update on disk/RAM if necessary
self._update_config_file_version()
self.cache = ModelCache(
max_cache_size=max_cache_size,
execution_device = device_type,
@ -226,8 +245,7 @@ class ModelManager(object):
self.cache_keys = dict()
self.logger = logger
# TODO: rename to smth like - is_model_exists
def valid_model(
def model_exists(
self,
model_name: str,
model_type: SDModelType = SDModelType.diffusers,
@ -246,14 +264,14 @@ class ModelManager(object):
model_type_str, model_name = model_key.split('/', 1)
if model_type_str not in SDModelType.__members__:
# TODO:
raise Exception(f"Unkown model type: {model_type_str}")
raise Exception(f"Unknown model type: {model_type_str}")
return (model_name, SDModelType[model_type_str])
def get_model(
self,
model_name: str,
model_type: SDModelType=None,
model_type: SDModelType=SDModelType.diffusers,
submodel: SDModelType=None,
) -> SDModelInfo:
"""Given a model named identified in models.yaml, return
@ -311,7 +329,7 @@ class ModelManager(object):
if model_type == SDModelType.diffusers:
# intercept stanzas that point to checkpoint weights and replace them
# with the equivalent diffusers model
if 'weights' in mconfig:
if mconfig.format in ["ckpt", "diffusers"]:
location = self.convert_ckpt_and_cache(mconfig)
else:
location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id')
@ -350,6 +368,7 @@ class ModelManager(object):
return SDModelInfo(
context = model_context,
name = model_name,
type = submodel or model_type,
hash = hash,
location = location,
revision = revision,
@ -358,43 +377,50 @@ class ModelManager(object):
_cache = self.cache
)
def default_model(self) -> Union[str,None]:
def default_model(self) -> Union[Tuple(str, SDModelType),None]:
"""
Returns the name of the default model, or None
if none is defined.
"""
for model_name in self.config:
if self.config[model_name].get("default"):
return model_name
return list(self.config.keys())[0] # first one
for model_name, model_type in self.model_names():
model_key = self.create_key(model_name, model_type)
if self.config[model_key].get("default"):
return (model_name, model_type)
return self.model_names()[0][0]
def set_default_model(self, model_name: str) -> None:
def set_default_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> None:
"""
Set the default model. The change will not take
effect until you call model_manager.commit()
"""
assert model_name in self.model_names(), f"unknown model '{model_name}'"
assert self.model_exists(model_name, model_type), f"unknown model '{model_name}'"
config = self.config
for model in config:
config[model].pop("default", None)
config[model_name]["default"] = True
for model_name, model_type in self.model_names():
key = self.create_key(model_name, model_type)
config[key].pop("default", None)
config[self.create_key(model_name, model_type)]["default"] = True
def model_info(self, model_name: str) -> dict:
def model_info(
self,
model_name: str,
model_type: SDModelType=SDModelType.diffusers
) -> dict:
"""
Given a model name returns the OmegaConf (dict-like) object describing it.
"""
if model_name not in self.config:
if not self.exists(model_name, model_type):
return None
return self.config[model_name]
return self.config[self.create_key(model_name,model_type)]
def model_names(self) -> list[str]:
def model_names(self) -> List[Tuple(str, SDModelType)]:
"""
Return a list consisting of all the names of models defined in models.yaml
Return a list of (str, SDModelType) corresponding to all models
known to the configuration.
"""
return list(self.config.keys())
return [(self.parse_key(x)) for x in self.config.keys() if isinstance(self.config[x],DictConfig)]
def is_legacy(self, model_name: str) -> bool:
def is_legacy(self, model_name: str, model_type: SDModelType.diffusers) -> bool:
"""
Return true if this is a legacy (.ckpt) model
"""
@ -402,7 +428,7 @@ class ModelManager(object):
# there are no legacy ckpts!
if Globals.ckpt_convert:
return False
info = self.model_info(model_name)
info = self.model_info(model_name, model_type)
if "weights" in info and info["weights"].endswith((".ckpt", ".safetensors")):
return True
return False
@ -422,15 +448,17 @@ class ModelManager(object):
# don't include VAEs in listing (legacy style)
if "config" in stanza and "/VAE/" in stanza["config"]:
continue
if model_key=='config_file_version':
continue
model_name, model_type = self.parse_key(model_key)
models[model_name] = dict()
models[model_key] = dict()
# TODO: return all models in future
if model_type != SDModelType.diffusers:
continue
model_format = "ckpt" if "weights" in stanza else "diffusers"
model_format = stanza.get('format')
# Common Attribs
status = self.cache.status(
@ -439,26 +467,25 @@ class ModelManager(object):
subfolder=stanza.get('subfolder')
)
description = stanza.get("description", None)
models[model_name].update(
description=description,
type=model_type,
models[model_key].update(
model_name=model_name,
model_type=model_type.name,
format=model_format,
description=description,
status=status.value
)
# Checkpoint Config Parse
if model_format == "ckpt":
models[model_name].update(
if model_format in ["ckpt","safetensors"]:
models[model_key].update(
config = str(stanza.get("config", None)),
weights = str(stanza.get("weights", None)),
vae = str(stanza.get("vae", None)),
width = str(stanza.get("width", 512)),
height = str(stanza.get("height", 512)),
)
# Diffusers Config Parse
elif model_format == "diffusers":
elif model_format == "folder":
if vae := stanza.get("vae", None):
if isinstance(vae, DictConfig):
vae = dict(
@ -467,7 +494,7 @@ class ModelManager(object):
subfolder = str(vae.get("subfolder", None)),
)
models[model_name].update(
models[model_key].update(
vae = vae,
repo_id = str(stanza.get("repo_id", None)),
path = str(stanza.get("path", None)),
@ -479,12 +506,9 @@ class ModelManager(object):
"""
Print a table of models, their descriptions, and load status
"""
models = self.list_models()
for name in models:
if models[name]["type"] == "vae":
continue
line = f'{name:25s} {models[name]["status"]:>15s} {models[name]["type"]:10s} {models[name]["description"]}'
if models[name]["status"] == "active":
for model_key, model_info in self.list_models().items():
line = f'{model_info["model_name"]:25s} {model_info["status"]:>15s} {model_info["model_type"]:10s} {model_info["description"]}'
if model_info["status"] in ["in gpu","locked in gpu"]:
line = f"\033[1m{line}\033[0m"
print(line)
@ -511,9 +535,9 @@ class ModelManager(object):
# self.stack.remove(model_name)
if delete_files:
repo_id = conf.get("repo_id", None)
path = self._abs_path(conf.get("path", None))
weights = self._abs_path(conf.get("weights", None))
repo_id = model_cfg.get("repo_id", None)
path = self._abs_path(model_cfg.get("path", None))
weights = self._abs_path(model_cfg.get("weights", None))
if "weights" in model_cfg:
weights = self._abs_path(model_cfg["weights"])
self.logger.info(f"Deleting file {weights}")
@ -558,7 +582,7 @@ class ModelManager(object):
), 'model must have either the "path" or "repo_id" fields defined'
elif model_format == "ckpt":
for field in ("description", "weights", "height", "width", "config"):
for field in ("description", "weights", "config"):
assert field in model_attributes, f"required field {field} is missing"
else:
@ -579,11 +603,6 @@ class ModelManager(object):
self.cache.uncache_model(self.cache_keys[model_key])
del self.cache_keys[model_key]
def import_diffuser_model(
self,
repo_or_path: Union[str, Path],
@ -604,7 +623,6 @@ class ModelManager(object):
models.yaml file.
"""
model_name = model_name or Path(repo_or_path).stem
model_key = f'{model_name}/diffusers'
model_description = description or f"Imported diffusers model {model_name}"
new_config = dict(
description=model_description,
@ -616,10 +634,10 @@ class ModelManager(object):
else:
new_config.update(repo_id=repo_or_path)
self.add_model(model_key, new_config, True)
self.add_model(model_name, SDModelType.diffusers, new_config, True)
if commit_to_conf:
self.commit(commit_to_conf)
return model_key
return self.create_key(model_name, SDModelType.diffusers)
def import_lora(
self,
@ -635,7 +653,8 @@ class ModelManager(object):
model_name = model_name or path.stem
model_description = description or f"LoRA model {model_name}"
self.add_model(
f'{model_name}/{SDModelType.lora.name}',
model_name,
SDModelType.lora,
dict(
format="lora",
weights=str(path),
@ -663,7 +682,8 @@ class ModelManager(object):
model_name = model_name or path.stem
model_description = description or f"Textual embedding model {model_name}"
self.add_model(
f'{model_name}/{SDModelType.textual_inversion.name}',
model_name,
SDModelType.textual_inversion,
dict(
format="textual_inversion",
weights=str(weights),
@ -1025,9 +1045,14 @@ class ModelManager(object):
description=model_description,
format="diffusers",
)
if model_name in self.config:
self.del_model(model_name)
self.add_model(model_name, new_config, True)
if self.model_exists(model_name, SDModelType.diffusers):
self.del_model(model_name, SDModelType.diffusers)
self.add_model(
model_name,
SDModelType.diffusers,
new_config,
True
)
if commit_to_conf:
self.commit(commit_to_conf)
self.logger.debug("Conversion succeeded")
@ -1081,11 +1106,6 @@ class ModelManager(object):
"""\
# This file describes the alternative machine learning models
# available to InvokeAI script.
#
# To add a new model, follow the examples below. Each
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
"""
)
@ -1130,3 +1150,53 @@ class ModelManager(object):
source = os.path.join(Globals.root, source)
resolved_path = Path(source)
return resolved_path
def _update_config_file_version(self):
"""
This gets called at object init time and will update
from older versions of the config file to new ones
as necessary.
"""
current_version = self.config.get("config_file_version","1.0.0")
if version.parse(current_version) < version.parse(CONFIG_FILE_VERSION):
self.logger.info(f'models.yaml version {current_version} detected. Updating to {CONFIG_FILE_VERSION}')
new_config = OmegaConf.create()
new_config["config_file_version"] = CONFIG_FILE_VERSION
for model_key in self.config:
old_stanza = self.config[model_key]
# ignore old and ugly way of associating a legacy
# vae with a legacy checkpont model
if old_stanza.get("config") and '/VAE/' in old_stanza.get("config"):
continue
# bare keys are updated to be prefixed with 'diffusers/'
if '/' not in model_key:
new_key = f'diffusers/{model_key}'
else:
new_key = model_key
if old_stanza.get('format')=='diffusers':
model_format = 'folder'
elif old_stanza.get('weights') and Path(old_stanza.get('weights')).suffix == '.ckpt':
model_format = 'ckpt'
elif old_stanza.get('weights') and Path(old_stanza.get('weights')).suffix == '.safetensors':
model_format = 'safetensors'
# copy fields over manually rather than doing a copy() or deepcopy()
# in order to avoid bringing in unwanted fields.
new_config[new_key] = dict(
description = old_stanza.get('description'),
format = model_format,
)
for field in ["repo_id", "path", "weights", "config", "vae"]:
if field_value := old_stanza.get(field):
new_config[new_key].update({field: field_value})
self.config = new_config
if self.config_path:
self.commit()

View File

@ -345,13 +345,14 @@ class Completer(object):
partial = text
matches = list()
for s in self.models:
name = self.models[s]["model_name"]
format = self.models[s]["format"]
if format == "vae":
continue
if ckpt_only and format != "ckpt":
continue
if s.startswith(partial):
matches.append(switch + s)
if name.startswith(partial):
matches.append(switch + name)
matches.sort()
return matches