import torch
import torch.nn as nn
import torch.nn.functional as F
import wget
import json
import os

IMAGEGEN_FOLDER = "./ImageGenModel"
IMAGEGEN_MODEL_WEIGHTS = "diffusion_pytorch_model.bin"
IMAGEGEN_CONFIG = "config.json"
IMAGEGEN_MODEL_URL = "https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/diffusion_pytorch_model.bin"
IMAGEGEN_CONFIG_URL = "https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/config.json"

IMAGEGEN_FILES_URLS = [
    (IMAGEGEN_MODEL_URL, IMAGEGEN_MODEL_WEIGHTS),
    (IMAGEGEN_CONFIG_URL, IMAGEGEN_CONFIG),
]

def ensure_imagegen_files_exist():
    # Download the model weights and config once; files already on disk are skipped.
    os.makedirs(IMAGEGEN_FOLDER, exist_ok=True)
    for url, filename in IMAGEGEN_FILES_URLS:
        filepath = os.path.join(IMAGEGEN_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)
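
# Usage sketch (assumes network access and that the wget package is installed):
#
#     ensure_imagegen_files_exist()
#     # -> ./ImageGenModel/diffusion_pytorch_model.bin and ./ImageGenModel/config.json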

class UNet2DConditionModelConfig:
    def __init__(self, **kwargs):
        self.sample_size = 64
        self.layers_per_block = 2
        self.block_out_channels = [320, 640, 1280, 1280]
        self.downsample = [2, 2, 2, 2]
        self.upsample = [2, 2, 2, 2]
        self.cross_attention_dim = 768
        self.act_fn = "silu"
        self.norm_num_groups = 32
        self.num_attention_heads = 8
        # Any keys passed in (e.g. from a downloaded config.json) override the defaults above.
        for key, value in kwargs.items():
            setattr(self, key, value)

    @classmethod
    def from_dict(cls, config_dict):
        return cls(**config_dict)
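
# Example sketch: build a config from the downloaded config.json. Keys in that
# file which this simplified class does not predefine are still attached as
# attributes via **kwargs; keys it does predefine override the defaults.
#
#     with open(os.path.join(IMAGEGEN_FOLDER, IMAGEGEN_CONFIG)) as f:
#         unet_config = UNet2DConditionModelConfig.from_dict(json.load(f))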

class UNet2DConditionModel(nn.Module):
    def __init__(self, config: UNet2DConditionModelConfig):
        super().__init__()
        self.conv_in = nn.Conv2d(4, config.block_out_channels[0], kernel_size=3, padding=1)

        # Down path: track the running channel count so each block's first
        # ResnetBlock projects from the previous block's output channels.
        self.down_blocks = nn.ModuleList([])
        in_channels = config.block_out_channels[0]
        for i, out_channels in enumerate(config.block_out_channels):
            is_final_block = i == len(config.block_out_channels) - 1
            downsample_factor = 1 if is_final_block else config.downsample[i]
            self.down_blocks.append(
                DownBlock(in_channels, out_channels, config.layers_per_block, downsample_factor)
            )
            in_channels = out_channels

        self.mid_block = MidBlock(config.block_out_channels[-1])

        # Up path mirrors the down path with the channel list reversed.
        self.up_blocks = nn.ModuleList([])
        reversed_block_out_channels = list(reversed(config.block_out_channels))
        reversed_upsample_factors = list(reversed(config.upsample))
        in_channels = reversed_block_out_channels[0]
        for i, out_channels in enumerate(reversed_block_out_channels):
            is_final_block = i == len(config.block_out_channels) - 1
            upsample_factor = 1 if is_final_block else reversed_upsample_factors[i]
            self.up_blocks.append(
                UpBlock(in_channels, out_channels, config.layers_per_block, upsample_factor)
            )
            in_channels = out_channels

        self.norm_out = nn.GroupNorm(num_groups=config.norm_num_groups, num_channels=config.block_out_channels[0])
        self.conv_norm_out = nn.Conv2d(config.block_out_channels[0], config.block_out_channels[0], kernel_size=3, padding=1)
        self.conv_out = nn.Conv2d(config.block_out_channels[0], 4, kernel_size=3, padding=1)

    def forward(self, sample: torch.FloatTensor, timestep: torch.IntTensor, encoder_hidden_states: torch.FloatTensor):
        # Note: this simplified skeleton has no timestep embedding, no cross-attention,
        # and no skip connections between the down and up paths, so timestep and
        # encoder_hidden_states are accepted but not yet used.
        sample = self.conv_in(sample)
        for down_block in self.down_blocks:
            sample = down_block(sample)
        sample = self.mid_block(sample)
        for up_block in self.up_blocks:
            sample = up_block(sample)
        sample = self.norm_out(sample)
        sample = F.silu(sample)
        sample = self.conv_norm_out(sample)
        sample = F.silu(sample)
        sample = self.conv_out(sample)
        return {"sample": sample}

class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, layers_per_block, downsample_factor):
        super().__init__()
        # The first ResnetBlock projects in_channels -> out_channels; the rest keep out_channels.
        self.layers = nn.ModuleList(
            [ResnetBlock(in_channels if i == 0 else out_channels, out_channels) for i in range(layers_per_block)]
        )
        if downsample_factor > 1:
            self.downsample = Downsample2D(out_channels, downsample_factor)
        else:
            self.downsample = nn.Identity()

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        x = self.downsample(x)
        return x

class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, layers_per_block, upsample_factor):
        super().__init__()
        # The first ResnetBlock projects in_channels -> out_channels; the rest keep out_channels.
        self.layers = nn.ModuleList(
            [ResnetBlock(in_channels if i == 0 else out_channels, out_channels) for i in range(layers_per_block)]
        )
        if upsample_factor > 1:
            self.upsample = Upsample2D(out_channels, upsample_factor)
        else:
            self.upsample = nn.Identity()

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        x = self.upsample(x)
        return x

class ResnetBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        # 1x1 convolution so the residual matches out_channels when the block changes width.
        self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        residual = x
        x = self.norm1(x)
        x = F.silu(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = F.silu(x)
        x = self.conv2(x)
        return x + self.residual_conv(residual)

class MidBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=channels)
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.norm1(x)
        x = F.silu(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = F.silu(x)
        x = self.conv2(x)
        return x

class Downsample2D(nn.Module):
    def __init__(self, channels, factor):
        super().__init__()
        self.factor = factor
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=factor, padding=1)

    def forward(self, x):
        return self.conv(x)

class Upsample2D(nn.Module):
    def __init__(self, channels, factor):
        super().__init__()
        self.factor = factor
        self.conv = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor)

    def forward(self, x):
        return self.conv(x)
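
# Minimal smoke test (a sketch, assuming the default config values above):
# instantiate the simplified UNet and run one dummy forward pass. The timestep
# and text embedding tensors are placeholders; this skeleton accepts them but
# does not condition on them yet.
if __name__ == "__main__":
    config = UNet2DConditionModelConfig()
    model = UNet2DConditionModel(config)
    latents = torch.randn(1, 4, config.sample_size, config.sample_size)
    timestep = torch.tensor([10], dtype=torch.int64)
    text_embeddings = torch.randn(1, 77, config.cross_attention_dim)
    with torch.no_grad():
        out = model(latents, timestep, text_embeddings)
    print(out["sample"].shape)  # expected: torch.Size([1, 4, 64, 64])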