pavel_777

transformer pipeline

Feb 3rd, 2022 (edited)
88
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.93 KB | None | 0 0
# Train a transformer on the sequence of the 128 most recent clusters to predict a set of 20 clusters
  2.  
  3. import pandas as pd
  4. import numpy as np
  5.  
  6. import torch
  7. from torch import nn
  8. from torch.utils.data import DataLoader, Dataset
  9. from torch.nn import functional as F
  10.  
  11. import pytorch_lightning as pl
  12. from pl_bolts.datasets import DummyDataset
  13.  
  14. from sklearn.model_selection import train_test_split
  15.  
  16. import math
  17. import gc
  18. import time
  19.  
  20.  
  21. str_device = "cuda" if torch.cuda.is_available() else "cpu"
  22. device = torch.device(str_device)
  23.  
  24.  
  25. input_path = './input'
  26.  
  27.  
  28. PATH_PL = 'sber-tran-9.ckpt'
  29.  
class params:
    """Hyper-parameters and run configuration, used as a plain namespace."""
    n_classes = 8_000  # number of distinct cluster ids (size of the model output)
    emsize = 300 # embedding dimension
    nhead = 6 # the number of heads in the multiheadattention models
    nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder
    nlayers = 2 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
    dropout = 0.1 # the dropout value
    batch_size = 1_024
    seq_size = 128  # length of the cluster-id history fed to the encoder
    lr=0.001 # 1e-3
    test_size = 0.1  # fraction of rows held out for validation
    random_state = 12345
    epochs = 6
  43.    
  44.  
# Load the raw training events (one row per user event with a cluster_id).
train_data = pd.read_parquet(f'{input_path}/train.parquet')

# Per-cluster weights, sorted by cluster id so positions align with ids below.
cluster_weights = pd.read_parquet(f'{input_path}/cluster-weights.parquet')
cluster_weights = cluster_weights.sort_values(by='cluster_id', ascending=True)

w_cluster = cluster_weights['w'].values
print(w_cluster.shape, w_cluster)

idx_cluster = cluster_weights['cluster_id'].values
print(idx_cluster.shape, idx_cluster)

# Dense weight vector over all n_classes clusters; clusters absent from the
# weights table keep a small default weight of 0.1.
w = np.zeros(params.n_classes) + 0.1
w[idx_cluster] = w_cluster
print(w[:10], w.shape)

# Calendar month of each event (data appears to span months 6..9, judging by
# the hard-coded column names below — TODO confirm against the data).
train_data['month'] = train_data['completed_at'].dt.month


# For every (id, month), collect the list of cluster ids seen that month,
# then pivot months into columns; missing (id, month) cells become ''.
df_m = train_data.groupby(['id', 'month'])['cluster_id'].apply(list)

df_month_h = df_m.unstack(fill_value='').reset_index()

df_month_h.columns = ['id', 'm_6', 'm_7', 'm_8', 'm_9']


# Raw events are no longer needed; release the memory.
del train_data
gc.collect()
  72.  
  73.  
  74. # m_6 ... m_9
  75. def add_prev(r):
  76.     a = r.prev
  77.     if type(a)==str: # nan
  78.         a=[]
  79.     a = list(set(a))
  80.    
  81.     if len(a)>=params.seq_size:
  82.         return a[:params.seq_size]
  83.    
  84.     for i in range(r.month-2, 5, -1):
  85.         clusters = r[f'm_{i}']
  86.         if type(clusters)==str: # nan
  87.             continue
  88.         if len(clusters) > 0:
  89.             a = list(set(a).union(set(clusters)))
  90.         if len(a)>=params.seq_size:
  91.             return a[:params.seq_size]
  92.    
  93.     return a[:params.seq_size]
  94.  
  95.  
  96. def get_train_df():
  97.     df = None
  98.  
  99.     for i in range(7,10):
  100.         print('month', i)
  101.         df_temp = df_month_h[df_month_h[f'm_{i}']!=''].copy()
  102.         df_temp['prev'] = df_temp[f'm_{i-1}']
  103.         df_temp['target'] = df_temp[f'm_{i}']
  104.         df_temp['month'] = i
  105.  
  106.         if i > 7:
  107.             df_temp['prev'] = df_temp.apply(add_prev, axis=1)
  108.         else:
  109.             df_temp['prev'] = df_temp['prev'].apply(lambda x: x if type(x)!=str else [])
  110.        
  111.         if df is None:
  112.             df = df_temp
  113.         else:
  114.             df = pd.concat([df, df_temp])
  115.         del df_temp
  116.         gc.collect()
  117.    
  118.        
  119.     return df[['id', 'prev', 'target']]
  120.  
  121.  
df = get_train_df()


# Row-wise split into training and validation parts (reproducible via seed).
df_train, df_val = train_test_split(
    df,
    test_size=params.test_size,
    random_state=params.random_state,
)


print(df_train.shape, df_val.shape)


# Free the intermediates; only df_train / df_val are used from here on.
del df_month_h
del df
gc.collect()
  138.  
  139. class MyDummyDataset(Dataset):
  140.     def __init__(self, df):
  141.           self.data = df
  142.        
  143.     def __len__(self):
  144.         return len(self.data)
  145.    
  146.     def __getitem__(self, idx):
  147.         prev = self.data.iloc[idx]['prev']
  148.         target_s = set(self.data.iloc[idx]['target'])
  149.        
  150.         x_l = prev[:params.seq_size]
  151.         if len(x_l) < params.seq_size:
  152.             x_l = x_l + [params.n_classes] * (params.seq_size-len(x_l))
  153.            
  154.         diff_s = target_s - set(prev)
  155.         ost_s = target_s - diff_s
  156.         target_idx = list(diff_s)
  157.         if len(diff_s) < 20:
  158.             ost_l = list(ost_s)
  159.             if len(ost_l) > 0:
  160.                 df = pd.DataFrame({'c':ost_l, 'w':w[ost_l]})
  161.                 df = df.sort_values(by='w', ascending=False)
  162.                 target_idx = list(diff_s) + list(df.head(20-len(diff_s))['c'].values)
  163.        
  164.         Y = np.zeros(params.n_classes)
  165.         Y[target_idx] = 1.0
  166.  
  167.         return torch.tensor(x_l).int(), Y
  168.  
  169. train = MyDummyDataset(df_train)
  170. train = DataLoader(train, batch_size=params.batch_size, num_workers=0)
  171.  
  172. val = MyDummyDataset(df_val)
  173. val = DataLoader(val, batch_size=params.batch_size, num_workers=0)
  174.  
  175.  
  176. class MyTransformer(pl.LightningModule):
  177.  
  178.     def __init__(self):
  179.         super().__init__()
  180.        
  181.         self.src_mask = None
  182.         encoder_layers = nn.TransformerEncoderLayer(d_model=params.emsize,
  183.                                                  nhead=params.nhead,
  184.                                                  dim_feedforward=params.nhid,
  185.                                                  dropout=params.dropout,
  186. #                                                  batch_first=True,
  187.                                                 )
  188.         self.transformer_encoder = nn.TransformerEncoder(encoder_layer=encoder_layers,
  189.                                                       num_layers=params.nlayers,
  190.                                                      )
  191.         self.encoder = nn.Embedding(params.n_classes + 1, params.emsize) # еще служебный 8_000
  192.         self.emsize = params.emsize
  193.         self.pool = nn.AvgPool2d((params.seq_size, 1))
  194.         self.decoder = nn.Linear(params.emsize, params.n_classes)
  195.         self.w_1 = w
  196.  
  197.         self.init_weights()
  198.        
  199.     def _generate_square_subsequent_mask(self, sz):
  200.         mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
  201.         mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
  202.        
  203.         mask = torch.zeros(sz, sz)
  204.         return mask
  205.  
  206.     def init_weights(self):
  207.         initrange = 0.1
  208.         self.encoder.weight.data.uniform_(-initrange, initrange)
  209.         self.decoder.bias.data.zero_()
  210.         self.decoder.weight.data.uniform_(-initrange, initrange)
  211.  
  212.     def forward(self, x):
  213.         if self.src_mask is None or self.src_mask.size(0) != len(x):
  214.             device = x.device
  215.             mask = self._generate_square_subsequent_mask(params.seq_size).to(device)
  216.             self.src_mask = mask
  217.  
  218.         src = self.encoder(x) * math.sqrt(self.emsize)
  219.         src = src.permute(1, 0, 2)
  220.  
  221.         output = self.transformer_encoder(src, self.src_mask)
  222.         output = output.permute(1, 0, 2)
  223.  
  224.         output = self.pool(output)
  225.         x_hat = self.decoder(output)
  226.         return x_hat
  227.    
  228.     def training_step(self, batch, batch_idx):
  229.         x, y = batch
  230.         x_hat = self(x)
  231.         w_all = self.get_w()
  232.         loss = F.binary_cross_entropy_with_logits(x_hat.view(-1, params.n_classes), y, weight=w_all)
  233.         self.log('train_loss', loss)
  234.         return loss
  235.  
  236.     def validation_step(self, batch, batch_idx):
  237.         x, y = batch
  238.         x_hat = self(x)
  239.         w_all = self.get_w()
  240.         loss = F.binary_cross_entropy_with_logits(x_hat.view(-1, params.n_classes), y, weight=w_all)
  241.         self.log('val_loss', loss, prog_bar=True)
  242.         return loss
  243.    
  244.     def get_w(self):
  245.         w_all = torch.tensor(self.w_1).float().to(device)
  246.         return w_all
  247.  
  248.     def test_step(self, batch, batch_idx):
  249.         return self.validation_step(batch, batch_idx)
  250.  
  251.     def configure_optimizers(self):
  252.         optimizer = torch.optim.Adam(self.parameters(), lr=params.lr)
  253.         return optimizer
  254.  
  255.  
  256. model = MyTransformer()
  257.  
  258.  
  259. if str_device == "cuda":
  260.     trainer = pl.Trainer(gpus=1, max_epochs=params.epochs, progress_bar_refresh_rate=20)
  261. else:
  262.     trainer = pl.Trainer(gpus=0, max_epochs=params.epochs, progress_bar_refresh_rate=20)
  263.  
  264. # Train the model
  265. trainer.fit(model, train, val)
  266.  
  267. trainer.save_checkpoint(PATH_PL)
  268.  
  269.  
Add Comment
Please, Sign In to add comment