import time
import argparse
from dataclasses import dataclass
from pathlib import Path
from typing import List, Sequence
import sys
from datetime import datetime, timedelta

import numpy as np
import torch
import torch.nn.functional as F

from model_architect.UNet_DDPM import UNet_with_time, DDPM


@dataclass
class Config:
    """Hyper-parameters for the UNet backbone and the DDPM wrapper.

    The defaults are the ones used for the '1hr' horizon; ``main`` overrides
    ``input_frame``/``output_frame`` (72/36) for the '6hr' horizon.
    """

    input_frame: int = 12
    output_frame: int = 6
    cond_nc: int = 5
    time_emb_dim: int = 128
    base_chs: int = 32
    chs_mult: tuple = (1, 2, 4, 8, 8)  ## different resolution
    use_attn_list: tuple = (0, 0, 1, 1, 1)  # 0 means no attention, 1 means use attention
    n_res_blocks: int = 2
    n_steps: int = 1000
    dropout: float = 0.1


# Checkpoint file for each supported prediction horizon (--pred-hr choice).
_CKPT_PATHS = {
    '1hr': './model_weights/ft06_01hr/weights.ckpt',
    '6hr': './model_weights/ft36_06hr/weights.ckpt',
}


def data_loading(BASETIME, device):
    """Load ``./sample_data/sample_{BASETIME}.npz`` as a dict of tensors.

    Args:
        BASETIME: Timestamp string embedded in the sample file name
            (the CLI default is '202504131100').
        device: ``torch.device`` that every array is moved to.

    Returns:
        dict mapping each array name stored in the archive to a tensor on
        ``device`` (dtype is whatever the archive stores).
    """
    inputs = {}
    # NpzFile holds an open file handle; the context manager closes it once
    # every array has been read into memory.
    with np.load(f'./sample_data/sample_{BASETIME}.npz') as data_npz:
        for key in data_npz:
            inputs[key] = torch.from_numpy(data_npz[key]).to(device)
    return inputs


def arg_parse():
    """Parse the command-line options of the inference script.

    Returns:
        argparse.Namespace with ``pred_hr`` ('1hr' or '6hr'),
        ``pred_mode`` ('DDPM' or 'DDIM') and ``basetime`` (str).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--pred-hr',
        type=str,
        default='1hr',
        choices=['1hr', '6hr'],
        help='prediction horizon; selects model size and checkpoint',
    )
    parser.add_argument(
        '--pred-mode',
        type=str,
        default='DDPM',
        choices=['DDPM', 'DDIM'],
        help='sampler used to generate the prediction',
    )
    parser.add_argument(
        '--basetime',
        type=str,
        default='202504131100',
        help='timestamp of the sample file in ./sample_data',
    )
    args = parser.parse_args()
    return args


def main():
    """Run one forecast from the CLI options and save it as a .npy file.

    The output is written to ``./pred_{basetime}_{pred_hr}_{pred_mode}.npy``.
    """
    args = arg_parse()
    pred_hr = args.pred_hr
    pred_mode = args.pred_mode
    BASETIME = args.basetime

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = data_loading(BASETIME, device)

    model_config = Config()
    if pred_hr == '6hr':
        model_config.input_frame = 72
        model_config.output_frame = 36

    print("Prediction mode:", pred_mode)
    print("Prediction horizon:", pred_hr)

    ## preprocess inputs for the DDPM model
    ## concat previous Himawari and topo as conditional input (B, 5, 512, 512)
    ## WRF dim: (B, 36, 512, 512). 1hr: (B, 6, 512, 512), 6hr: (B, 36, 512, 512)
    # NOTE(review): squeeze(2) presumes a singleton axis 2 in the stored
    # Himawari/WRF arrays — confirm against the sample data layout.
    prev_himawari = inputs['Himawari'].squeeze(2)
    topo = inputs['topo']
    input_ = torch.cat([prev_himawari, topo], dim=1)

    # Upsample WRF spatially by 4x.
    WRF = F.interpolate(inputs['WRF'].squeeze(2), scale_factor=4, mode='bilinear')
    clearsky = inputs['clearsky']
    if pred_hr == '1hr':
        # Keep only the frames the 1hr model predicts (output_frame == 6).
        WRF = WRF[:, :model_config.output_frame]
        clearsky = clearsky[:, :model_config.output_frame]

    backbone = UNet_with_time(model_config)
    model = DDPM(backbone, output_shape=(model_config.output_frame, 512, 512))

    ## load model weights
    # map_location='cpu' lets a checkpoint saved from CUDA tensors load on a
    # CPU-only machine; the model is moved to `device` afterwards.
    ckpt = torch.load(_CKPT_PATHS[pred_hr], map_location='cpu', weights_only=True)
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    model = model.to(device)

    # Inference only: disable autograd so no graph is built and the result
    # can be converted to numpy directly.
    with torch.no_grad():
        if pred_mode == 'DDPM':
            pred_clr_idx = model.sample_ddpm(
                input_, input_cond=WRF, verbose="text"
            )
        else:  # 'DDIM' (argparse restricts the choices)
            pred_clr_idx = model.sample_ddim(
                input_, input_cond=WRF, ddim_steps=100, verbose="text"
            )

        # Map the sampler output from [-1, 1] to [0, 1] and clip.
        pred_clr_idx = (pred_clr_idx + 1.0) / 2.0
        pred_clr_idx = pred_clr_idx.clamp(0.0, 1.0)

        ## transform clearsky index to solar radiation
        pred_srad = pred_clr_idx * clearsky

    ## save prediction
    np.save(
        f'./pred_{BASETIME}_{pred_hr}_{pred_mode}.npy',
        pred_srad.detach().cpu().numpy(),
    )
    print('Done')


if __name__ == "__main__":
    main()