mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
adjusted regression tests to work with new SDModelTypes
This commit is contained in:
parent
baf5451fa0
commit
426f4eaf7e
@ -2,10 +2,8 @@
|
|||||||
|
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
import diffusers
|
|
||||||
import einops
|
import einops
|
||||||
import torch
|
import torch
|
||||||
from diffusers import DiffusionPipeline
|
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
from diffusers.image_processor import VaeImageProcessor
|
from diffusers.image_processor import VaeImageProcessor
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
@ -22,18 +20,16 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
|
|||||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
|
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
|
||||||
PostprocessingSettings
|
PostprocessingSettings
|
||||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||||
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
|
||||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
from .baseinvocation import (
|
||||||
import numpy as np
|
BaseInvocation, BaseInvocationOutput,
|
||||||
|
InvocationContext, InvocationConfig
|
||||||
|
)
|
||||||
from ..services.image_storage import ImageType
|
from ..services.image_storage import ImageType
|
||||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
|
||||||
InvocationConfig, InvocationContext)
|
|
||||||
from .compel import ConditioningField
|
from .compel import ConditioningField
|
||||||
from .image import ImageField, ImageOutput, build_image_output
|
from .image import ImageField, ImageOutput, build_image_output
|
||||||
|
|
||||||
from .model import ModelInfo, UNetField, VaeField
|
from .model import ModelInfo, UNetField, VaeField
|
||||||
from ...backend.model_management import SDModelType
|
|
||||||
|
|
||||||
|
|
||||||
class LatentsField(BaseModel):
|
class LatentsField(BaseModel):
|
||||||
@ -213,7 +209,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
h_symmetry_time_pct=None,#h_symmetry_time_pct,
|
h_symmetry_time_pct=None,#h_symmetry_time_pct,
|
||||||
v_symmetry_time_pct=None#v_symmetry_time_pct,
|
v_symmetry_time_pct=None#v_symmetry_time_pct,
|
||||||
),
|
),
|
||||||
).add_scheduler_args_if_applicable(scheduler, eta=None)#ddim_eta)
|
).add_scheduler_args_if_applicable(scheduler, eta=0.0)#ddim_eta)
|
||||||
return conditioning_data
|
return conditioning_data
|
||||||
|
|
||||||
def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
|
def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
|
||||||
|
@ -2,7 +2,7 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from invokeai.backend.model_management.model_cache import ModelCache
|
from invokeai.backend.model_management.model_cache import ModelCache, MODEL_CLASSES
|
||||||
|
|
||||||
class DummyModelBase(object):
|
class DummyModelBase(object):
|
||||||
'''Base class for dummy component of a diffusers model'''
|
'''Base class for dummy component of a diffusers model'''
|
||||||
@ -32,13 +32,21 @@ class DummyPipeline(DummyModelBase):
|
|||||||
'''Dummy pipeline object is a composite of several types'''
|
'''Dummy pipeline object is a composite of several types'''
|
||||||
def __init__(self,repo_id):
|
def __init__(self,repo_id):
|
||||||
super().__init__(repo_id)
|
super().__init__(repo_id)
|
||||||
self.type1 = DummyModelType1('dummy/type1')
|
self.dummy_model_type1 = DummyModelType1('dummy/type1')
|
||||||
self.type2 = DummyModelType2('dummy/type2')
|
self.dummy_model_type2 = DummyModelType2('dummy/type2')
|
||||||
|
|
||||||
class DMType(Enum):
|
class DMType(str, Enum):
|
||||||
dummy_pipeline = DummyPipeline
|
dummy_pipeline = 'dummy_pipeline'
|
||||||
type1 = DummyModelType1
|
type1 = 'dummy_model_type1'
|
||||||
type2 = DummyModelType2
|
type2 = 'dummy_model_type2'
|
||||||
|
|
||||||
|
MODEL_CLASSES.update(
|
||||||
|
{
|
||||||
|
DMType.dummy_pipeline: DummyPipeline,
|
||||||
|
DMType.type1: DummyModelType1,
|
||||||
|
DMType.type2: DummyModelType2,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
cache = ModelCache(max_cache_size=4)
|
cache = ModelCache(max_cache_size=4)
|
||||||
|
|
||||||
@ -50,7 +58,7 @@ def test_pipeline_fetch():
|
|||||||
assert pipeline1 is not None, 'get_model() should not return None'
|
assert pipeline1 is not None, 'get_model() should not return None'
|
||||||
assert pipeline1a is not None, 'get_model() should not return None'
|
assert pipeline1a is not None, 'get_model() should not return None'
|
||||||
assert pipeline2 is not None, 'get_model() should not return None'
|
assert pipeline2 is not None, 'get_model() should not return None'
|
||||||
assert type(pipeline1)==DMType.dummy_pipeline.value,'get_model() did not return model of expected type'
|
assert type(pipeline1)==DummyPipeline,'get_model() did not return model of expected type'
|
||||||
assert pipeline1==pipeline1a,'pipelines with the same repo_id should be the same'
|
assert pipeline1==pipeline1a,'pipelines with the same repo_id should be the same'
|
||||||
assert pipeline1!=pipeline2,'pipelines with different repo_ids should not be the same'
|
assert pipeline1!=pipeline2,'pipelines with different repo_ids should not be the same'
|
||||||
assert len(cache.models)==2,'cache should uniquely cache models with same identity'
|
assert len(cache.models)==2,'cache should uniquely cache models with same identity'
|
||||||
@ -77,6 +85,6 @@ def test_submodel_fetch():
|
|||||||
cache.get_model(repo_id_or_path='dummy/pipeline2',model_type=DMType.dummy_pipeline,submodel=DMType.type1) as part2:
|
cache.get_model(repo_id_or_path='dummy/pipeline2',model_type=DMType.dummy_pipeline,submodel=DMType.type1) as part2:
|
||||||
assert type(part1)==DummyModelType1,'returned submodel is not of expected type'
|
assert type(part1)==DummyModelType1,'returned submodel is not of expected type'
|
||||||
assert part1.device==torch.device('cuda'),'returned submodel should be in the GPU when in context'
|
assert part1.device==torch.device('cuda'),'returned submodel should be in the GPU when in context'
|
||||||
assert pipeline.type1==part1,'returned submodel should match the corresponding subpart of parent model'
|
assert pipeline.dummy_model_type1==part1,'returned submodel should match the corresponding subpart of parent model'
|
||||||
assert pipeline.type1!=part2,'returned submodel should not match the subpart of a different parent'
|
assert pipeline.dummy_model_type1!=part2,'returned submodel should not match the subpart of a different parent'
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user