Merge branch 'main' into mm-ui

This commit is contained in:
blessedcoolant
2023-07-14 15:46:53 +12:00
160 changed files with 4147 additions and 4892 deletions

View File

@ -121,8 +121,8 @@ class ModelInstall(object):
installed_models = self.mgr.list_models()
for md in installed_models:
base = md['base_model']
model_type = md['type']
name = md['name']
model_type = md['model_type']
name = md['model_name']
key = ModelManager.create_key(name, base, model_type)
if key in model_dict:
model_dict[key].installed = True

View File

@ -538,9 +538,9 @@ class ModelManager(object):
model_dict = dict(
**model_config.dict(exclude_defaults=True),
# OpenAPIModelInfoBase
name=cur_model_name,
model_name=cur_model_name,
base_model=cur_base_model,
type=cur_model_type,
model_type=cur_model_type,
)
models.append(model_dict)

View File

@ -37,9 +37,9 @@ MODEL_CONFIGS = list()
OPENAPI_MODEL_CONFIGS = list()
class OpenAPIModelInfoBase(BaseModel):
name: str
model_name: str
base_model: BaseModelType
type: ModelType
model_type: ModelType
for base_model, models in MODEL_CLASSES.items():
@ -56,7 +56,7 @@ for base_model, models in MODEL_CLASSES.items():
api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict(
__annotations__ = dict(
type=Literal[model_type.value],
model_type=Literal[model_type.value],
),
))

View File

@ -127,7 +127,7 @@ class AddsMaskGuidance:
def _t_for_field(self, field_name: str, t):
if field_name == "pred_original_sample":
return torch.zeros_like(t, dtype=t.dtype) # it represents t=0
return self.scheduler.timesteps[-1]
return t
def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
@ -631,7 +631,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
control_latent_input = torch.cat([unet_latent_input] * 2)
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings])
encoder_hidden_states = conditioning_data.text_embeddings
else:
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings])