Video Generation Latent Space Memory Estimation

Detection, Video & Advanced Vision DS practice problem on Onlearn.

Difficulty: medium.

Topics: Video Generation Latent Space Memory Estimation, KL-Divergence Regularization, VRAM Allocation Profiling, Latent Diffusion Bottleneck, Temporal Attention Sparsity, Autoencoder Compression Ratio, Generative Modeling, Information Theory, Computational Complexity, Computer Vision, Statistical Learning, Latent Space Manifold Learning, Diffusion Probabilistic Models, Memory Bottleneck Analysis, High-Dimensional Tensor Operations, Temporal Consistency Modeling.

Modern video generation models (such as those based on Diffusion Transformers) encode raw video frames into a compressed latent space using a 3D Variational Autoencoder (VAE). The VAE compresses both the spatial dimensions (height and width) and the temporal dimension (number of frames). Understanding the memory footprint of these latent representations is crucial for planning GPU resources during training and inference.

Write a function estimate_video_latent_memory that computes the latent tensor shape, the total memory consumption, and the number of transformer tokens produced after optional spatial patchification.

The function should accept the following parameters:
- num_frames (int): Number of video frames.
- height (int): Pixel height of each frame.
- width (int): Pixel width of each frame.
- latent_channels (int): Number of channels in the latent space.
- spatial_compression (int): Factor by which height and width are downsampled.
- temporal_compression (int): Factor by which the frame count is downsampled.
- batch_size (int, default 1): Number of videos in the batch.
- dtype (str, default "fp16"): Data type; one of "fp32", "fp16", "bf16", "fp8".
- patch_size (int, default 1): Spatial patch size used to convert latent features into transformer tokens.

The function should return a dictionary with the following keys:
- latent_shape: A tuple (batch_size, latent_channels, latent_t, latent_h, latent_w).
- num_elements: Total number of elements in the latent tensor (int).
- memory_bytes: Total memory in bytes (int).
- memory_mb: Total memory in megabytes, rounded to 4 decimal places (float).
- tokens_per_video: Number of transformer tokens per single video after patchification (int).

Use ceiling division when computing the compressed and patchified dimensions so that sizes that are not evenly divisible are handled correctly.
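A minimal Python sketch of one way to meet this spec follows. It rests on a few assumptions the statement does not pin down: 1 MB is taken as 1024 × 1024 bytes (the statement does not say decimal or binary), element sizes are 4/2/2/1 bytes for fp32/fp16/bf16/fp8, and patchification is assumed to turn each patch_size × patch_size block of every latent frame into one token while leaving the temporal axis unpatched, so tokens_per_video = latent_t × ceil(latent_h / patch_size) × ceil(latent_w / patch_size). The helper _ceil_div and the constant names are hypothetical, introduced only for this example.

```python
from typing import Dict, Tuple, Union

# Bytes per element for each supported dtype.
BYTES_PER_ELEMENT: Dict[str, int] = {"fp32": 4, "fp16": 2, "bf16": 2, "fp8": 1}

# Assumption: "MB" means 1024 * 1024 bytes (MiB); the problem does not specify.
BYTES_PER_MB = 1024 * 1024


def _ceil_div(a: int, b: int) -> int:
    """Integer ceiling division for positive b (hypothetical helper)."""
    return -(-a // b)


def estimate_video_latent_memory(
    num_frames: int,
    height: int,
    width: int,
    latent_channels: int,
    spatial_compression: int,
    temporal_compression: int,
    batch_size: int = 1,
    dtype: str = "fp16",
    patch_size: int = 1,
) -> Dict[str, Union[Tuple[int, int, int, int, int], int, float]]:
    if dtype not in BYTES_PER_ELEMENT:
        raise ValueError(f"unsupported dtype: {dtype!r}")

    # Compressed latent dimensions; ceiling division keeps partial blocks.
    latent_t = _ceil_div(num_frames, temporal_compression)
    latent_h = _ceil_div(height, spatial_compression)
    latent_w = _ceil_div(width, spatial_compression)

    latent_shape = (batch_size, latent_channels, latent_t, latent_h, latent_w)
    num_elements = batch_size * latent_channels * latent_t * latent_h * latent_w
    memory_bytes = num_elements * BYTES_PER_ELEMENT[dtype]

    # Assumed spatial-only patchification: one token per patch_size x patch_size
    # block in each latent frame. Counted per video, so batch_size is excluded.
    tokens_per_video = (
        latent_t * _ceil_div(latent_h, patch_size) * _ceil_div(latent_w, patch_size)
    )

    return {
        "latent_shape": latent_shape,
        "num_elements": num_elements,
        "memory_bytes": memory_bytes,
        "memory_mb": round(memory_bytes / BYTES_PER_MB, 4),
        "tokens_per_video": tokens_per_video,
    }


if __name__ == "__main__":
    # 49 frames at 480x720, 16 latent channels, 8x spatial / 4x temporal
    # compression, fp16, 2x2 spatial patches.
    result = estimate_video_latent_memory(49, 480, 720, 16, 8, 4, patch_size=2)
    print(result)
    # {'latent_shape': (1, 16, 13, 60, 90), 'num_elements': 1123200,
    #  'memory_bytes': 2246400, 'memory_mb': 2.1423, 'tokens_per_video': 17550}
```

In the example, 49 frames become ceil(49 / 4) = 13 latent frames, so the last partial group of frames still occupies a full latent step. Integer ceiling division (-(-a // b)) is used instead of math.ceil(a / b) to avoid floating-point rounding on very large sizes.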