Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 64 additions & 59 deletions tests/models/autoencoders/test_models_autoencoder_kl_kvae_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,48 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import pytest
import torch

from diffusers import AutoencoderKLKVAEVideo
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
from .testing_utils import NewAutoencoderTesterMixin


enable_full_determinism()


class AutoencoderKLKVAEVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLKVAEVideo
main_input_name = "sample"
base_precision = 1e-2
def _run_nondeterministic(fn):
# reflection_pad3d_backward_out_cuda has no deterministic CUDA implementation;
# temporarily relax the requirement for tests that do backward passes.
torch.use_deterministic_algorithms(False)
try:
fn()
finally:
torch.use_deterministic_algorithms(True)


class AutoencoderKLKVAEVideoTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return AutoencoderKLKVAEVideo

@property
def main_input_name(self) -> str:
return "sample"

@property
def output_shape(self) -> tuple:
return (3, 3, 16, 16)

@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)

def get_autoencoder_kl_kvae_video_config(self):
def get_init_dict(self) -> dict:
return {
"ch": 32,
"ch_mult": (1, 2),
Expand All @@ -41,78 +65,59 @@ def get_autoencoder_kl_kvae_video_config(self):
"temporal_compress_times": 2,
}

@property
def dummy_input(self):
def get_dummy_inputs(self) -> dict:
batch_size = 2
num_frames = 3 # satisfies (T-1) % temporal_compress_times == 0 with temporal_compress_times=2
num_channels = 3
sizes = (16, 16)

video = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)

video = randn_tensor(
(batch_size, num_channels, num_frames, *sizes), generator=self.generator, device=torch_device
)
return {"sample": video}

@property
def input_shape(self):
return (3, 3, 16, 16)

@property
def output_shape(self):
return (3, 3, 16, 16)

def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_kvae_video_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def test_gradient_checkpointing_is_applied(self):
expected_set = {
"KVAECachedEncoder3D",
"KVAECachedDecoder3D",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestAutoencoderKLKVAEVideo(AutoencoderKLKVAEVideoTesterConfig, ModelTesterMixin):
base_precision = 1e-2

@unittest.skip("Unsupported test.")
@pytest.mark.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass
super().test_outputs_equivalence()

@unittest.skip(
@pytest.mark.skip(
"Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
)
def test_model_parallelism(self):
pass
super().test_model_parallelism()

@unittest.skip(
"Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
)
def test_sharded_checkpoints_device_map(self):
pass

def _run_nondeterministic(self, fn):
# reflection_pad3d_backward_out_cuda has no deterministic CUDA implementation;
# temporarily relax the requirement for training tests that do backward passes.
import torch
class TestAutoencoderKLKVAEVideoTraining(AutoencoderKLKVAEVideoTesterConfig, TrainingTesterMixin):
"""Training tests for AutoencoderKLKVAEVideo."""

torch.use_deterministic_algorithms(False)
try:
fn()
finally:
torch.use_deterministic_algorithms(True)
def test_gradient_checkpointing_is_applied(self):
expected_set = {"KVAECachedEncoder3D", "KVAECachedDecoder3D"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

def test_training(self):
self._run_nondeterministic(super().test_training)
_run_nondeterministic(super().test_training)

def test_ema_training(self):
self._run_nondeterministic(super().test_ema_training)
def test_training_with_ema(self):
_run_nondeterministic(super().test_training_with_ema)

@unittest.skip(
@pytest.mark.skip(
"Gradient checkpointing recomputes the forward pass, but the model uses a stateful cache_dict "
"that is mutated during the first forward. On recomputation the cache is already populated, "
"causing a different execution path and numerically different gradients. "
"GC still reduces peak memory usage; gradient correctness in the presence of GC is a known limitation."
"causing a different execution path and numerically different gradients."
)
def test_effective_gradient_checkpointing(self):
pass
def test_gradient_checkpointing_equivalence(self):
super().test_gradient_checkpointing_equivalence()

def test_layerwise_casting_training(self):
self._run_nondeterministic(super().test_layerwise_casting_training)
_run_nondeterministic(super().test_layerwise_casting_training)


class TestAutoencoderKLKVAEVideoMemory(AutoencoderKLKVAEVideoTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for AutoencoderKLKVAEVideo."""


class TestAutoencoderKLKVAEVideoSlicingTiling(AutoencoderKLKVAEVideoTesterConfig, NewAutoencoderTesterMixin):
"""Slicing and tiling tests for AutoencoderKLKVAEVideo."""
Loading
Loading