Fix conflict resolve, add model configs to type annotation

This commit is contained in:
Sergey Borisov 2023-06-14 00:26:37 +03:00
parent c9ae26a176
commit 26090011c4
3 changed files with 39 additions and 27 deletions

View File

@ -7,6 +7,8 @@ from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management.models import get_all_model_configs
MODEL_CONFIGS = Union[tuple(get_all_model_configs())]
models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -60,7 +62,7 @@ class ConvertedModelResponse(BaseModel):
info: DiffusersModelInfo = Field(description="The converted model info")
class ModelsList(BaseModel):
models: Dict[BaseModelType, Dict[ModelType, Dict[str, dict]]] # TODO: collect all configs
models: Dict[BaseModelType, Dict[ModelType, Dict[str, MODEL_CONFIGS]]] # TODO: debug/discuss with frontend
#models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]

View File

@ -3,6 +3,8 @@
from contextlib import ExitStack
from typing import List, Literal, Optional, Union
import einops
from pydantic import BaseModel, Field, validator
import torch
from diffusers import ControlNetModel
@ -235,11 +237,12 @@ class TextToLatentsInvocation(BaseInvocation):
return conditioning_data
def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
configure_model_padding(
unet,
self.seamless,
self.seamless_axes,
)
# TODO:
#configure_model_padding(
# unet,
# self.seamless,
# self.seamless_axes,
#)
class FakeVae:
class FakeVaeConfig:
@ -261,13 +264,15 @@ class TextToLatentsInvocation(BaseInvocation):
precision="float16" if unet.dtype == torch.float16 else "float32",
)
def prep_control_data(self,
context: InvocationContext,
model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
control_input: List[ControlField],
latents_shape: List[int],
do_classifier_free_guidance: bool = True,
) -> List[ControlNetData]:
def prep_control_data(
self,
context: InvocationContext,
model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
control_input: List[ControlField],
latents_shape: List[int],
do_classifier_free_guidance: bool = True,
) -> List[ControlNetData]:
# assuming fixed dimensional scaling of 8:1 for image:latents
control_height_resize = latents_shape[2] * 8
control_width_resize = latents_shape[3] * 8
@ -362,10 +367,12 @@ class TextToLatentsInvocation(BaseInvocation):
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,)
control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
)
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
# TODO: Verify the noise is the right size
@ -434,11 +441,12 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler)
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
)
control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
)
# TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(

View File

@ -33,10 +33,12 @@ MODEL_CLASSES = {
#},
}
# TODO: check with openapi annotation
def get_all_model_configs():
configs = []
configs = set()
for models in MODEL_CLASSES.values():
for model in models.values():
configs.extend(model._get_configs())
return configs
for type, model in models.items():
if type == ModelType.ControlNet:
continue # TODO:
configs.update(model._get_configs().values())
configs.discard(None)
return list(configs) # TODO: set, list or tuple