2023-05-02 02:57:30 +00:00
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
|
2023-05-03 16:38:18 +00:00
|
|
|
from enum import Enum
|
|
|
|
from invokeai.backend.model_management.model_cache import ModelCache
|
2023-05-02 02:57:30 +00:00
|
|
|
|
2023-05-03 16:38:18 +00:00
|
|
|
class DummyModelBase:
    '''Base class for dummy component of a diffusers model.

    Stands in for a real model class in the cache tests. It records the
    repo_id it was created for and the device it was last moved to, and
    holds no weights.
    '''

    def __init__(self, repo_id):
        self.repo_id = repo_id
        # Freshly constructed dummies always start out on the CPU.
        self.device = torch.device('cpu')

    @classmethod
    def from_pretrained(cls,
                        repo_id: str,
                        revision: str = None,
                        subfolder: str = None,
                        cache_dir: str = None,
                        **kwargs,
                        ):
        '''Mimic a ``from_pretrained()`` loader: return a new instance for ``repo_id``.

        ``revision``, ``subfolder``, ``cache_dir`` and any additional keyword
        arguments (e.g. dtype options a real loader might receive) are
        accepted only for signature compatibility and are ignored.
        '''
        return cls(repo_id)

    def to(self, device):
        '''Record ``device`` as this model's location and return ``self``.

        Returning ``self`` matches ``torch.nn.Module.to()``, so callers may
        use either ``model.to(dev)`` or ``model = model.to(dev)``.
        '''
        self.device = device
        return self
|
2023-05-02 02:57:30 +00:00
|
|
|
|
2023-05-03 16:38:18 +00:00
|
|
|
class DummyModelType1(DummyModelBase):
    '''First kind of dummy component; adds nothing to DummyModelBase.'''
|
|
|
|
|
|
|
|
class DummyModelType2(DummyModelBase):
    '''Second kind of dummy component; adds nothing to DummyModelBase.'''
|
|
|
|
|
|
|
|
class DummyPipeline(DummyModelBase):
    '''Dummy pipeline object is a composite of several types.

    It owns one component of each dummy model type (``type1`` and
    ``type2``). Moving the pipeline with ``to()`` also moves those
    components, as moving a real composite pipeline would.
    '''

    def __init__(self, repo_id):
        super().__init__(repo_id)
        # Component repo_ids are fixed placeholders, independent of the
        # pipeline's own repo_id.
        self.type1 = DummyModelType1('dummy/type1')
        self.type2 = DummyModelType2('dummy/type2')

    def to(self, device):
        '''Move the pipeline and each of its components to ``device``; return ``self``.'''
        super().to(device)
        for component in (self.type1, self.type2):
            component.to(device)
        return self
|
|
|
|
|
|
|
|
class DMType(Enum):
    '''Dummy model types used as ``model_type``/``submodel`` arguments in these tests.

    Each member's value is a dummy model class. The tests expect
    ``ModelCache.get_model`` to return an instance of that class.
    '''
    dummy_pipeline = DummyPipeline   # composite holding type1 and type2
    type1 = DummyModelType1
    type2 = DummyModelType2
|
|
|
|
|
2023-05-08 03:18:17 +00:00
|
|
|
# A single cache instance shared by every test in this module, so state
# carries over from one test to the next. test_pipeline_fetch asserts
# cache.cache_size()==0 on entry, so it presumably must run before the
# other tests populate the cache.
# NOTE(review): a per-test pytest fixture would isolate the tests; the unit
# of max_cache_size is not visible here — confirm against ModelCache.
cache = ModelCache(max_cache_size=4)
|
2023-05-02 02:57:30 +00:00
|
|
|
|
|
|
|
def test_pipeline_fetch():
    '''get_model() caches one pipeline per repo_id and grows as new repo_ids are requested.

    Expects the module-level ``cache`` to start out empty.
    '''
    assert cache.cache_size()==0
    with cache.get_model('dummy/pipeline1',DMType.dummy_pipeline) as pipeline1,\
         cache.get_model('dummy/pipeline1',DMType.dummy_pipeline) as pipeline1a,\
         cache.get_model('dummy/pipeline2',DMType.dummy_pipeline) as pipeline2:
        assert pipeline1 is not None, 'get_model() should not return None'
        assert pipeline1a is not None, 'get_model() should not return None'
        assert pipeline2 is not None, 'get_model() should not return None'
        # Exact-type check: a subclass instance should not pass.
        assert type(pipeline1) is DMType.dummy_pipeline.value,'get_model() did not return model of expected type'
        # Identity, not equality: the cache must hand back the very same object.
        assert pipeline1 is pipeline1a,'pipelines with the same repo_id should be the same'
        assert pipeline1 is not pipeline2,'pipelines with different repo_ids should not be the same'
        assert len(cache.models)==2,'cache should uniquely cache models with same identity'
    with cache.get_model('dummy/pipeline3',DMType.dummy_pipeline) as pipeline3,\
         cache.get_model('dummy/pipeline4',DMType.dummy_pipeline) as pipeline4:
        assert pipeline3 is not None, 'get_model() should not return None'
        assert pipeline4 is not None, 'get_model() should not return None'
        assert len(cache.models)==4,'cache did not grow as expected'
|
2023-05-03 16:38:18 +00:00
|
|
|
|
|
|
|
def test_signatures():
    '''Models sharing a repo_id are cached as separate objects per revision and per subfolder.'''
    with cache.get_model('dummy/pipeline',DMType.dummy_pipeline,revision='main') as pipeline1,\
         cache.get_model('dummy/pipeline',DMType.dummy_pipeline,revision='fp16') as pipeline2,\
         cache.get_model('dummy/pipeline',DMType.dummy_pipeline,revision='main',subfolder='foo') as pipeline3:
        # Identity checks: distinct cache keys must yield distinct objects.
        assert pipeline1 is not pipeline2,'models are distinguished by their revision'
        assert pipeline1 is not pipeline3,'models are distinguished by their subfolder'
        assert pipeline2 is not pipeline3,'models are distinguished by their revision and subfolder'
|
|
|
|
|
|
|
|
def test_pipeline_device():
    '''Inside the context the model sits on the GPU, unless gpu_load=False keeps it on the CPU.'''
    gpu = torch.device('cuda')
    cpu = torch.device('cpu')

    with cache.get_model('dummy/pipeline1',DMType.type1) as loaded:
        assert loaded.device==gpu,'when in context, model device should be in GPU'

    with cache.get_model('dummy/pipeline1',DMType.type1, gpu_load=False) as kept_on_cpu:
        assert kept_on_cpu.device==cpu,'when gpu_load=False, model device should be CPU'
|
2023-05-02 02:57:30 +00:00
|
|
|
|
|
|
|
def test_submodel_fetch():
    '''Requesting a submodel returns the matching component object of the cached parent pipeline.'''
    with cache.get_model(repo_id_or_path='dummy/pipeline1',model_type=DMType.dummy_pipeline) as pipeline,\
         cache.get_model(repo_id_or_path='dummy/pipeline1',model_type=DMType.dummy_pipeline,submodel=DMType.type1) as part1,\
         cache.get_model(repo_id_or_path='dummy/pipeline2',model_type=DMType.dummy_pipeline,submodel=DMType.type1) as part2:
        # Exact-type check: a subclass instance should not pass.
        assert type(part1) is DummyModelType1,'returned submodel is not of expected type'
        assert part1.device==torch.device('cuda'),'returned submodel should be in the GPU when in context'
        # Identity, not equality: the submodel must be the parent's own component object.
        assert pipeline.type1 is part1,'returned submodel should match the corresponding subpart of parent model'
        assert pipeline.type1 is not part2,'returned submodel should not match the subpart of a different parent'
|
|
|
|
|