adjusted regression tests to work with new SDModelTypes

Lincoln Stein 2023-05-13 22:29:33 -04:00
parent baf5451fa0
commit 426f4eaf7e
2 changed files with 23 additions and 19 deletions

View File

@@ -2,10 +2,8 @@
 from typing import Literal, Optional, Union
-import diffusers
 import einops
 import torch
-from diffusers import DiffusionPipeline
 from diffusers.schedulers import SchedulerMixin as Scheduler
 from diffusers.image_processor import VaeImageProcessor
 from pydantic import BaseModel, Field
@@ -22,18 +20,16 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
     PostprocessingSettings
 from ...backend.util.devices import choose_torch_device, torch_dtype
-from ...backend.prompting.conditioning import get_uc_and_c_and_ec
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
 
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
-import numpy as np
+from .baseinvocation import (
+    BaseInvocation, BaseInvocationOutput,
+    InvocationContext, InvocationConfig
+)
 from ..services.image_storage import ImageType
-from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
-                             InvocationConfig, InvocationContext)
 from .compel import ConditioningField
 from .image import ImageField, ImageOutput, build_image_output
 from .model import ModelInfo, UNetField, VaeField
-from ...backend.model_management import SDModelType
 
 
 class LatentsField(BaseModel):
@@ -213,7 +209,7 @@ class TextToLatentsInvocation(BaseInvocation):
                 h_symmetry_time_pct=None,#h_symmetry_time_pct,
                 v_symmetry_time_pct=None#v_symmetry_time_pct,
             ),
-        ).add_scheduler_args_if_applicable(scheduler, eta=None)#ddim_eta)
+        ).add_scheduler_args_if_applicable(scheduler, eta=0.0)#ddim_eta)
         return conditioning_data
 
     def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
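
Note on the eta change above: eta scales the stochastic noise DDIM injects at each step, and 0.0 is the fully deterministic setting, a safer explicit default than passing None through to schedulers that multiply by eta. A minimal sketch of the usual signature-based gating for such extra step kwargs, assuming a diffusers-style scheduler object; extra_step_kwargs_for is a hypothetical helper, not InvokeAI's add_scheduler_args_if_applicable:

import inspect
from typing import Any, Dict


def extra_step_kwargs_for(scheduler: Any, eta: float = 0.0) -> Dict[str, Any]:
    """Build extra kwargs for scheduler.step(), passing eta only when the
    scheduler's step() signature actually accepts it (e.g. DDIM)."""
    step_params = inspect.signature(scheduler.step).parameters
    extra: Dict[str, Any] = {}
    if "eta" in step_params:
        extra["eta"] = eta  # 0.0 selects deterministic DDIM sampling
    return extra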

View File

@@ -2,7 +2,7 @@ import pytest
 import torch
 from enum import Enum
 
-from invokeai.backend.model_management.model_cache import ModelCache
+from invokeai.backend.model_management.model_cache import ModelCache, MODEL_CLASSES
 
 class DummyModelBase(object):
     '''Base class for dummy component of a diffusers model'''
@@ -32,13 +32,21 @@ class DummyPipeline(DummyModelBase):
     '''Dummy pipeline object is a composite of several types'''
     def __init__(self,repo_id):
         super().__init__(repo_id)
-        self.type1 = DummyModelType1('dummy/type1')
-        self.type2 = DummyModelType2('dummy/type2')
+        self.dummy_model_type1 = DummyModelType1('dummy/type1')
+        self.dummy_model_type2 = DummyModelType2('dummy/type2')
 
-class DMType(Enum):
-    dummy_pipeline = DummyPipeline
-    type1 = DummyModelType1
-    type2 = DummyModelType2
+class DMType(str, Enum):
+    dummy_pipeline = 'dummy_pipeline'
+    type1 = 'dummy_model_type1'
+    type2 = 'dummy_model_type2'
+
+MODEL_CLASSES.update(
+    {
+        DMType.dummy_pipeline: DummyPipeline,
+        DMType.type1: DummyModelType1,
+        DMType.type2: DummyModelType2,
+    }
+)
 
 
 cache = ModelCache(max_cache_size=4)
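
The hunk above follows the registry shape implied by importing MODEL_CLASSES: model types become a plain string-valued Enum, and a separate dict maps each member to the class to instantiate, which the test extends with its dummy types. A minimal self-contained sketch of that pattern, with illustrative names rather than the model manager's real API:

from enum import Enum
from typing import Dict, Type


class ModelType(str, Enum):
    # str-valued members compare and serialize as plain strings
    pipeline = 'pipeline'
    vae = 'vae'


class DemoPipeline:
    pass


class DemoVae:
    pass


# Registry mapping enum members to concrete classes; callers (or tests)
# can extend it with their own types.
MODEL_CLASSES: Dict[ModelType, Type] = {
    ModelType.pipeline: DemoPipeline,
    ModelType.vae: DemoVae,
}


def load_model(model_type: ModelType):
    # Resolve the class registered for the requested type and instantiate it.
    return MODEL_CLASSES[model_type]()


assert type(load_model(ModelType.pipeline)) is DemoPipeline

Keeping the enum purely string-valued means the type identifiers stay serializable, while the classes live only in the registry.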
@@ -50,7 +58,7 @@ def test_pipeline_fetch():
     assert pipeline1 is not None, 'get_model() should not return None'
     assert pipeline1a is not None, 'get_model() should not return None'
     assert pipeline2 is not None, 'get_model() should not return None'
-    assert type(pipeline1)==DMType.dummy_pipeline.value,'get_model() did not return model of expected type'
+    assert type(pipeline1)==DummyPipeline,'get_model() did not return model of expected type'
     assert pipeline1==pipeline1a,'pipelines with the same repo_id should be the same'
     assert pipeline1!=pipeline2,'pipelines with different repo_ids should not be the same'
     assert len(cache.models)==2,'cache should uniquely cache models with same identity'
@@ -77,6 +85,6 @@ def test_submodel_fetch():
      cache.get_model(repo_id_or_path='dummy/pipeline2',model_type=DMType.dummy_pipeline,submodel=DMType.type1) as part2:
         assert type(part1)==DummyModelType1,'returned submodel is not of expected type'
         assert part1.device==torch.device('cuda'),'returned submodel should be in the GPU when in context'
-        assert pipeline.type1==part1,'returned submodel should match the corresponding subpart of parent model'
-        assert pipeline.type1!=part2,'returned submodel should not match the subpart of a different parent'
+        assert pipeline.dummy_model_type1==part1,'returned submodel should match the corresponding subpart of parent model'
+        assert pipeline.dummy_model_type1!=part2,'returned submodel should not match the subpart of a different parent'
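
The attribute renames in this last hunk (type1 -> dummy_model_type1, matching the new enum string values) suggest that a submodel is resolved by using the enum's value as an attribute name on the parent pipeline. A rough sketch of that lookup under that assumption; get_submodel is a hypothetical stand-in, not the actual ModelCache method:

from enum import Enum


class DMType(str, Enum):
    dummy_pipeline = 'dummy_pipeline'
    type1 = 'dummy_model_type1'
    type2 = 'dummy_model_type2'


class FakePipeline:
    def __init__(self):
        # Attribute names deliberately match the enum string values above.
        self.dummy_model_type1 = object()
        self.dummy_model_type2 = object()


def get_submodel(pipeline, submodel: DMType):
    # DMType.type1 -> getattr(pipeline, 'dummy_model_type1')
    return getattr(pipeline, submodel.value)


pipe = FakePipeline()
assert get_submodel(pipe, DMType.type1) is pipe.dummy_model_type1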