mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
remove redundant prediction_type and attention_upscaling flags
This commit is contained in:
parent
466ec3ab5e
commit
539d1f3bde
@ -631,8 +631,8 @@ def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
with open(root / 'invokeai.yaml','w') as f:
|
# with open(root / 'invokeai.yaml','w') as f:
|
||||||
f.write('#empty invokeai.yaml initialization file')
|
# f.write('#empty invokeai.yaml initialization file')
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def run_console_ui(
|
def run_console_ui(
|
||||||
|
@ -3,8 +3,6 @@ Utility (backend) functions used by model_install.py
|
|||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass,field
|
from dataclasses import dataclass,field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -12,10 +10,9 @@ from tempfile import TemporaryDirectory
|
|||||||
from typing import List, Dict, Callable, Union, Set
|
from typing import List, Dict, Callable, Union, Set
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from diffusers import AutoencoderKL, StableDiffusionPipeline
|
from diffusers import StableDiffusionPipeline
|
||||||
from huggingface_hub import hf_hub_url, HfFolder, HfApi
|
from huggingface_hub import hf_hub_url, HfFolder, HfApi
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from omegaconf.dictconfig import DictConfig
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import invokeai.configs as configs
|
import invokeai.configs as configs
|
||||||
@ -24,7 +21,6 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
|||||||
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType
|
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType
|
||||||
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
|
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
|
||||||
from invokeai.backend.util import download_with_resume
|
from invokeai.backend.util import download_with_resume
|
||||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
|
||||||
from ..util.logging import InvokeAILogger
|
from ..util.logging import InvokeAILogger
|
||||||
|
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
@ -290,7 +286,7 @@ class ModelInstall(object):
|
|||||||
location = self._download_hf_model(repo_id, files, staging)
|
location = self._download_hf_model(repo_id, files, staging)
|
||||||
break
|
break
|
||||||
elif f'learned_embeds.{suffix}' in files:
|
elif f'learned_embeds.{suffix}' in files:
|
||||||
location = self._download_hf_model(repo_id, [f'learned_embeds.suffix'], staging)
|
location = self._download_hf_model(repo_id, ['learned_embeds.suffix'], staging)
|
||||||
break
|
break
|
||||||
if not location:
|
if not location:
|
||||||
logger.warning(f'Could not determine type of repo {repo_id}. Skipping install.')
|
logger.warning(f'Could not determine type of repo {repo_id}. Skipping install.')
|
||||||
@ -307,7 +303,6 @@ class ModelInstall(object):
|
|||||||
self._install_path(dest, info)
|
self._install_path(dest, info)
|
||||||
|
|
||||||
def _make_attributes(self, path: Path, info: ModelProbeInfo)->dict:
|
def _make_attributes(self, path: Path, info: ModelProbeInfo)->dict:
|
||||||
|
|
||||||
# convoluted way to retrieve the description from datasets
|
# convoluted way to retrieve the description from datasets
|
||||||
description = f'{info.base_type.value} {info.model_type.value} model'
|
description = f'{info.base_type.value} {info.model_type.value} model'
|
||||||
if key := self.reverse_paths.get(self.current_id):
|
if key := self.reverse_paths.get(self.current_id):
|
||||||
@ -320,18 +315,7 @@ class ModelInstall(object):
|
|||||||
model_format = info.format,
|
model_format = info.format,
|
||||||
)
|
)
|
||||||
if info.model_type == ModelType.Pipeline:
|
if info.model_type == ModelType.Pipeline:
|
||||||
attributes.update(
|
attributes.update(dict(variant = info.variant_type,))
|
||||||
dict(
|
|
||||||
variant = info.variant_type,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if info.base_type == BaseModelType.StableDiffusion2:
|
|
||||||
attributes.update(
|
|
||||||
dict(
|
|
||||||
prediction_type = info.prediction_type,
|
|
||||||
upcast_attention = info.prediction_type == SchedulerPredictionType.VPrediction,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if info.format=="checkpoint":
|
if info.format=="checkpoint":
|
||||||
try:
|
try:
|
||||||
legacy_conf = LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type] if info.base_type == BaseModelType.StableDiffusion2 \
|
legacy_conf = LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type] if info.base_type == BaseModelType.StableDiffusion2 \
|
||||||
|
@ -131,17 +131,12 @@ class StableDiffusion2Model(DiffusersModel):
|
|||||||
model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
|
model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
prediction_type: SchedulerPredictionType
|
|
||||||
upcast_attention: bool
|
|
||||||
|
|
||||||
class CheckpointConfig(ModelConfigBase):
|
class CheckpointConfig(ModelConfigBase):
|
||||||
model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
|
model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
config: Optional[str] = Field(None)
|
config: str
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
prediction_type: SchedulerPredictionType
|
|
||||||
upcast_attention: bool
|
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
assert base_model == BaseModelType.StableDiffusion2
|
assert base_model == BaseModelType.StableDiffusion2
|
||||||
|
Loading…
x
Reference in New Issue
Block a user