Create a wrapper model to include normalization and de-normalization inside the model’s forward method

%cd /global/u2/z/zeyuanhu/public_codes/ClimSim/
from climsim_utils.data_utils import *
/usr/local/lib/python3.10/dist-packages/IPython/core/magics/osm.py:417: UserWarning: using dhist requires you to install the `pickleshare` library.
  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]
/global/u2/z/zeyuanhu/public_codes/ClimSim
2024-07-23 03:40:28.985421: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-23 03:40:28.985454: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-23 03:40:28.986968: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-07-23 03:40:28.994981: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
import torch
import numpy as np
import torch.optim as optim
import torch.nn as nn
import modulus

Create a wrapper model to include normalization and de-normalization inside the model’s forward method

We define below a new class, “NewModel”, that wraps the trained U-Net model (v5, i.e., with microphysics constraints applied) and includes all the preprocessing and post-processing steps inside its forward method.

%cd /global/u2/z/zeyuanhu/public_codes/ClimSim/online_testing/baseline_models/Unet_v5/training
from climsim_unet import ClimsimUnet
import climsim_unet as climsim_unet
/global/u2/z/zeyuanhu/public_codes/ClimSim/online_testing/baseline_models/Unet_v5/training
/usr/local/lib/python3.10/dist-packages/IPython/core/magics/osm.py:417: UserWarning: using dhist requires you to install the `pickleshare` library.
  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]
class NewModel(nn.Module):
    """Wrap a trained network so it can be fed raw inputs and return raw outputs.

    ``forward`` takes the un-normalized "v4" input array (1525 columns),
    converts it to the 1405-column "v5" layout, normalizes it, runs
    ``original_model``, de-normalizes the result, and splits the predicted
    total-cloud (qn) tendency back into separate liquid (qc) and ice (qi)
    tendencies. The returned array has 368 columns.

    Args:
        original_model: Trained network. It is called with a ``(batch, 1405)``
            tensor and must return ``(batch, 308)``: columns ``0:180`` are
            read individually and columns ``180:`` are copied into the last
            128 output columns.
        input_sub: Values subtracted from the v5 input (must broadcast against
            ``(batch, 1405)``).
        input_div: Values the shifted v5 input is divided by (same broadcasting
            requirement).
        out_scale: Values the network output is divided by (must broadcast
            against the network output).
        lbd_qn: Coefficients of the ``1 - exp(-qn * lbd_qn)`` transform applied
            to input columns ``120:180`` (must broadcast against
            ``(batch, 60)``).
    """

    def __init__(self, original_model, input_sub, input_div, out_scale, lbd_qn):
        super().__init__()
        self.original_model = original_model
        # Registered as buffers so that .to()/.cuda() move the constants
        # together with the wrapped network. persistent=False keeps them out of
        # state_dict(), so its keys are the same as with plain attributes.
        self.register_buffer(
            "input_sub", torch.tensor(input_sub, dtype=torch.float32), persistent=False
        )
        self.register_buffer(
            "input_div", torch.tensor(input_div, dtype=torch.float32), persistent=False
        )
        self.register_buffer(
            "out_scale", torch.tensor(out_scale, dtype=torch.float32), persistent=False
        )
        self.register_buffer(
            "lbd_qn", torch.tensor(lbd_qn, dtype=torch.float32), persistent=False
        )

    def apply_temperature_rules(self, T):
        """Map temperature to a fraction in ``[0, 1]``.

        Returns 0 where ``T < 253.16``, 1 where ``T > 273.16`` and a linear
        ramp ``(T - 253.16) / 20`` in between. NaN entries map to 0.
        ``forward`` uses the result as the liquid fraction of total cloud.
        """
        # 20.0 is the width of the transition range (273.16 - 253.16).
        ramp = (T - 253.16) / 20.0
        zeros = torch.zeros_like(T)
        ones = torch.ones_like(T)
        # Both comparisons are False for NaN, so NaN falls through to 0.
        return torch.where(T > 273.16, ones, torch.where(T >= 253.16, ramp, zeros))

    def preprocessing(self, x):
        """Convert a raw v4 input ``(batch, 1525)`` into the normalized v5 input.

        The input tensor is not modified; a new ``(batch, 1405)`` tensor is
        returned.
        """
        # convert v4 input array to v5 input array:
        v5 = torch.zeros((x.shape[0], 1405), dtype=x.dtype, device=x.device)
        v5[:, 0:120] = x[:, 0:120]
        v5[:, 120:180] = x[:, 120:180] + x[:, 180:240]  # qn = qc + qi
        v5[:, 180:240] = self.apply_temperature_rules(x[:, 0:60])
        v5[:, 240:840] = x[:, 240:840]  # 60*14
        v5[:, 840:900] = x[:, 840:900] + x[:, 900:960]  # dqc+dqi
        v5[:, 900:1080] = x[:, 960:1140]
        v5[:, 1080:1140] = x[:, 1140:1200] + x[:, 1200:1260]
        v5[:, 1140:1405] = x[:, 1260:1525]

        # do input normalization
        v5[:, 120:180] = 1 - torch.exp(-v5[:, 120:180] * self.lbd_qn)
        v5 = (v5 - self.input_sub) / self.input_div
        # NaN and +/-inf (e.g. from a zero entry in input_div) become 0.
        v5 = torch.nan_to_num(v5, nan=0.0, posinf=0.0, neginf=0.0)

        # prune top 15 levels in qn input
        v5[:, 120:120 + 15] = 0
        # clip rh input
        v5[:, 60:120] = torch.clamp(v5[:, 60:120], 0, 1.2)
        return v5

    def postprocessing(self, x):
        """Zero the first 15 levels of four output variables and undo output scaling.

        Note: the zeroing writes into ``x`` in place before the scaled copy is
        returned.
        """
        x[:, 60:75] = 0
        x[:, 120:135] = 0
        x[:, 180:195] = 0
        x[:, 240:255] = 0
        x = x / self.out_scale
        return x

    def forward(self, x):
        """Run the full pipeline on a raw v4 input and return ``(batch, 368)``.

        Output layout: columns ``0:120`` are the network's columns ``0:120``,
        ``120:180`` and ``180:240`` are the qc and qi tendencies derived from
        the predicted qn tendency, and ``240:`` are the network's columns
        ``180:``.
        """
        t_before = x[:, 0:60].clone()
        qc_before = x[:, 120:180].clone()
        qi_before = x[:, 180:240].clone()
        qn_before = qc_before + qi_before

        y = self.preprocessing(x)
        y = self.original_model(y)
        y = self.postprocessing(y)

        # 1200. converts a tendency into an increment (presumably the time
        # step in seconds -- TODO confirm).
        t_new = t_before + y[:, 0:60] * 1200.
        qn_new = qn_before + y[:, 120:180] * 1200.
        # Partition the updated total cloud by the updated temperature.
        liq_frac = self.apply_temperature_rules(t_new)
        qc_new = liq_frac * qn_new
        qi_new = (1 - liq_frac) * qn_new

        # Allocate with the network output's dtype/device so the result lives
        # where the computation happened.
        xout = torch.zeros((y.shape[0], 368), dtype=y.dtype, device=y.device)
        xout[:, 0:120] = y[:, 0:120]
        xout[:, 240:] = y[:, 180:]
        xout[:, 120:180] = (qc_new - qc_before) / 1200.
        xout[:, 180:240] = (qi_new - qi_before) / 1200.

        return xout
def save_wrapper(casename):
    """Wrap the trained model of ``casename`` in ``NewModel`` and save it as TorchScript.

    Loads ``model.mdlus`` and the ``inp_sub.txt`` / ``inp_div.txt`` /
    ``out_scale.txt`` normalization files from the case's ``saved_models``
    directory, plus the shared qn lambda file, builds a CPU ``NewModel``,
    scripts it with ``torch.jit.script`` and writes ``<casename>.pt`` into the
    ``saved_models_wrapper_tmp`` directory.

    Args:
        casename: Name of the sub-directory under ``saved_models`` that holds
            the checkpoint, e.g. ``'v5_noclassifier_huber_1y_noaggressive'``.

    Returns:
        None.
    """
    # Local import: os.path is needed below and is not among the imports
    # visible in this file.
    import os

    case_dir = f'/global/homes/z/zeyuanhu/scratch/hugging/E3SM-MMF_ne4/saved_models/{casename}'
    f_torch_model = os.path.join(case_dir, 'model.mdlus')
    f_inp_sub = os.path.join(case_dir, 'inp_sub.txt')
    f_inp_div = os.path.join(case_dir, 'inp_div.txt')
    f_out_scale = os.path.join(case_dir, 'out_scale.txt')
    f_qn_lbd = '/global/u2/z/zeyuanhu/nvidia_codes/Climsim_private/preprocessing/normalizations/inputs/qn_exp_lambda_large.txt'

    lbd_qn = np.loadtxt(f_qn_lbd, delimiter=',')
    input_sub = np.loadtxt(f_inp_sub, delimiter=',')
    input_div = np.loadtxt(f_inp_div, delimiter=',')
    out_scale = np.loadtxt(f_out_scale, delimiter=',')
    model_inf = modulus.Module.from_checkpoint(f_torch_model).to('cpu')

    # NewModel takes a single qn lambda array; the names lbd_qc / lbd_qi that
    # were passed here before are never defined in this function.
    new_model = NewModel(model_inf, input_sub, input_div, out_scale, lbd_qn)

    # Kept from the original: sets a class-level attribute on NewModel.
    # NOTE(review): nothing visible in this file reads it -- confirm it is
    # still needed.
    NewModel.device = "cpu"

    scripted_model = torch.jit.script(new_model)
    scripted_model = scripted_model.eval()

    out_dir = '/global/homes/z/zeyuanhu/scratch/hugging/E3SM-MMF_ne4/saved_models_wrapper_tmp/'
    # Create the output directory up front so a missing directory does not
    # make save() fail after the model has been loaded and scripted.
    os.makedirs(out_dir, exist_ok=True)
    save_file_torch = os.path.join(out_dir, f'{casename}.pt')
    scripted_model.save(save_file_torch)
    return None
# Export a TorchScript wrapper for each trained case, in this order.
for casename in (
    'v5_unet_nonaggressive_cliprh_huber_rop2_r2',
    'v5_unet_nonaggressive_cliprh_huber',
    'v5_unet_nonaggressive_cliprh_mae',
):
    save_wrapper(casename)