Advertisement
YomoMan

GCN+LightGBM.

Jun 7th, 2023
804
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.15 KB | None | 0 0
  1. class Solubility(nn.Module):
  2.     def __init__(self, n_features=30, in_channels=32, out_channels=16, gcn_in_channels=32, gcn_out_channels=16, dropout_rate=0.45):
  3.         super(Solubility, self).__init__()
  4.         torch.manual_seed(21)
  5.         self.conv1 = GCNConv(n_features, in_channels)
  6.         self.conv2 = GCNConv(in_channels, out_channels)
  7.  
  8.         # self.gcn = GCN(n_features=30, in_channels=gcn_in_channels, out_channels=gcn_out_channels)
  9.         self.lightgbm = None
  10.         self.fc = nn.Linear(out_channels + 1, 1)
  11.        
  12.         self.dropout_rate = dropout_rate
  13.         self.featurizer = dc.feat.RDKitDescriptors(use_fragment=False, ipc_avg=False, use_bcut2d=True)
  14.         self.normalizer = None
  15.  
  16.         self.reset_parameters()
  17.  
  18.     def forward(self, data):
  19.         x, edge_index = data.x, data.edge_index
  20.         x = self.conv1(x, edge_index).relu()
  21.         x = self.conv2(x, edge_index).relu()
  22.         x = global_mean_pool(x, data.batch)
  23.         x = F.dropout(x, p=self.dropout_rate, training=self.training)
  24.        
  25.         assert self.lightgbm is not None, "Train LightGBM model first"
  26.  
  27.         if self.normalizer:
  28.             predictions = self.lightgbm.predict(self.normalizer.transform(self.featurizer(data.smiles)))
  29.             lightgbm_output = torch.tensor(predictions, dtype=torch.float32).unsqueeze(1)                
  30.         else:
  31.             predictions = self.lightgbm.predict(self.featurizer(data.smiles))
  32.             lightgbm_output = torch.tensor(predictions, dtype=torch.float32).unsqueeze(1)
  33.         x = torch.hstack((x, lightgbm_output))
  34.  
  35.         x = self.fc(x)
  36.         return x
  37.  
  38.    
  39.     def _collect_dataset_data(self, dataset):
  40.         smiles, targets = zip(*((data.smiles, data.y.numpy()[0]) for data in dataset))
  41.         return np.array(smiles), np.array(targets)
  42.  
  43.     def train_lightgbm(self, train_dataset, valid_dataset, normalize_data=False):    
  44.         self.normalizer = DataNormalizer() if normalize_data else None
  45.  
  46.         train_smiles, y_train = self._collect_dataset_data(train_dataset)
  47.         X_train = pd.DataFrame(data=self.featurizer(train_smiles), columns=self.featurizer.descriptors)
  48.         print('X_train.shape=', X_train.shape)
  49.         if normalize_data: X_train = self.normalizer.fit_transform(X_train)
  50.        
  51.         valid_smiles, y_valid = self._collect_dataset_data(valid_dataset)
  52.         X_valid = pd.DataFrame(data=self.featurizer(valid_smiles), columns=self.featurizer.descriptors)    
  53.         print('X_valid.shape=', X_valid.shape)  
  54.         if normalize_data: X_valid = self.normalizer.transform(X_valid)
  55.        
  56.         train_data = lgb.Dataset(X_train, label=y_train)
  57.         valid_data = lgb.Dataset(X_valid, label=y_valid, reference=train_data)
  58.    
  59.         hyperparams = {
  60.             'objective': 'regression',
  61.             'metric': 'mse',
  62.             'boosting':'dart',
  63.             'n_estimators': 10000,
  64.             'early_stopping_rounds': 100,
  65.             'learning_rate': 0.3,
  66.             'feature_fraction': 0.45,
  67.             'bagging_freq': 5,
  68.             'bagging_fraction': 0.9,
  69.             'bagging_seed': 42
  70.         }
  71.         self.lightgbm = lgb.train(hyperparams, train_data,
  72.                                   valid_sets=[train_data, valid_data],
  73.                                   valid_names=['train', 'valid'])
  74.        
  75.     def reset_parameters(self):
  76.         self.conv1.reset_parameters()
  77.         self.conv2.reset_parameters()
  78.         self.fc.reset_parameters()
  79.        
  80.     def save_models(self):
  81.         import joblib
  82.         torch.save(self.state_dict(), 'gcn_fc_weights.pth')
  83.         joblib.dump(self.lightgbm, 'lgb_solubility_model.pkl')
  84.        
  85.     def load_models(self,
  86.                     gcn_fc_path='gcn_fc_weights.pth',
  87.                     lbg_path='lgb_solubility_model.pkl'):
  88.         import joblib
  89.         self.load_state_dict(torch.load(gcn_fc_path))
  90.         self.lightgbm = joblib.load(lbg_path)
  91.        
  92.     @classmethod
  93.     def from_pickle(cls,
  94.                     gcn_fc_path='gcn_fc_weights.pth',
  95.                     lbg_path='lgb_solubility_model.pkl'):
  96.         m = cls()
  97.         m.load_model(gcn_fc_path, lbg_path)
  98.         return m
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement