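# DataNormalizer is used by Solubility.train_lightgbm() below but is not
# included in this paste. A minimal stand-in, assuming it is a thin
# fit/transform wrapper around scikit-learn's StandardScaler (the real
# implementation may differ):
from sklearn.preprocessing import StandardScaler

class DataNormalizer:
    def __init__(self):
        self._scaler = StandardScaler()

    def fit_transform(self, X):
        # Learn per-feature mean/std from the training set, then scale it
        return self._scaler.fit_transform(X)

    def transform(self, X):
        # Scale new data with the statistics learned on the training set
        return self._scaler.transform(X)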
import deepchem as dc
import lightgbm as lgb
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

# Hybrid regressor: a two-layer GCN over the molecular graph, fused with a
# LightGBM prediction on RDKit descriptors by a final linear layer.
class Solubility(nn.Module):
    def __init__(self, n_features=30, in_channels=32, out_channels=16,
                 gcn_in_channels=32, gcn_out_channels=16, dropout_rate=0.45):
        super(Solubility, self).__init__()
        torch.manual_seed(21)
        self.conv1 = GCNConv(n_features, in_channels)
        self.conv2 = GCNConv(in_channels, out_channels)
        # self.gcn = GCN(n_features=30, in_channels=gcn_in_channels, out_channels=gcn_out_channels)
        self.lightgbm = None
        self.fc = nn.Linear(out_channels + 1, 1)  # +1 for the LightGBM prediction
        self.dropout_rate = dropout_rate
        self.featurizer = dc.feat.RDKitDescriptors(use_fragment=False, ipc_avg=False, use_bcut2d=True)
        self.normalizer = None
        self.reset_parameters()
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = global_mean_pool(x, data.batch)  # one embedding per molecule
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        assert self.lightgbm is not None, "Train the LightGBM model first"
        features = self.featurizer(data.smiles)
        if self.normalizer:
            features = self.normalizer.transform(features)
        predictions = self.lightgbm.predict(features)
        lightgbm_output = torch.tensor(predictions, dtype=torch.float32).unsqueeze(1)
        x = torch.hstack((x, lightgbm_output))  # append LightGBM output as an extra feature
        x = self.fc(x)
        return x
    def _collect_dataset_data(self, dataset):
        smiles, targets = zip(*((data.smiles, data.y.numpy()[0]) for data in dataset))
        return np.array(smiles), np.array(targets)
    def train_lightgbm(self, train_dataset, valid_dataset, normalize_data=False):
        self.normalizer = DataNormalizer() if normalize_data else None
        train_smiles, y_train = self._collect_dataset_data(train_dataset)
        X_train = pd.DataFrame(data=self.featurizer(train_smiles), columns=self.featurizer.descriptors)
        print('X_train.shape=', X_train.shape)
        if normalize_data:
            X_train = self.normalizer.fit_transform(X_train)
        valid_smiles, y_valid = self._collect_dataset_data(valid_dataset)
        X_valid = pd.DataFrame(data=self.featurizer(valid_smiles), columns=self.featurizer.descriptors)
        print('X_valid.shape=', X_valid.shape)
        if normalize_data:
            X_valid = self.normalizer.transform(X_valid)
        train_data = lgb.Dataset(X_train, label=y_train)
        valid_data = lgb.Dataset(X_valid, label=y_valid, reference=train_data)
        hyperparams = {
            'objective': 'regression',
            'metric': 'mse',
            # Note: LightGBM does not support early stopping with 'dart'
            # boosting, so 'early_stopping_rounds' has no effect here.
            'boosting': 'dart',
            'n_estimators': 10000,
            'early_stopping_rounds': 100,
            'learning_rate': 0.3,
            'feature_fraction': 0.45,
            'bagging_freq': 5,
            'bagging_fraction': 0.9,
            'bagging_seed': 42,
        }
        self.lightgbm = lgb.train(hyperparams, train_data,
                                  valid_sets=[train_data, valid_data],
                                  valid_names=['train', 'valid'])
    def reset_parameters(self):
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        self.fc.reset_parameters()
    def save_models(self):
        import joblib
        torch.save(self.state_dict(), 'gcn_fc_weights.pth')
        joblib.dump(self.lightgbm, 'lgb_solubility_model.pkl')

    def load_models(self,
                    gcn_fc_path='gcn_fc_weights.pth',
                    lgb_path='lgb_solubility_model.pkl'):
        import joblib
        self.load_state_dict(torch.load(gcn_fc_path))
        self.lightgbm = joblib.load(lgb_path)
    @classmethod
    def from_pickle(cls,
                    gcn_fc_path='gcn_fc_weights.pth',
                    lgb_path='lgb_solubility_model.pkl'):
        m = cls()
        m.load_models(gcn_fc_path, lgb_path)  # fixed: was load_model(), which does not exist
        return m
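# Example usage, as a sketch: `train_dataset` and `valid_dataset` are assumed
# to be PyTorch Geometric datasets whose Data objects carry .x, .edge_index,
# .smiles, and .y, as the class above expects; the dataset names, batch size,
# learning rate, and epoch count are illustrative, not from the original paste.
if __name__ == '__main__':
    from torch_geometric.loader import DataLoader

    model = Solubility(n_features=30)
    # Stage 1: fit the LightGBM model on RDKit descriptors
    model.train_lightgbm(train_dataset, valid_dataset, normalize_data=True)

    # Stage 2: train the GCN + fusion layer; gradients flow only through the
    # GCN/FC parameters, since the LightGBM output enters as a constant feature
    loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.MSELoss()
    model.train()
    for epoch in range(100):
        for batch in loader:
            optimizer.zero_grad()
            pred = model(batch).view(-1)
            loss = loss_fn(pred, batch.y.view(-1).float())
            loss.backward()
            optimizer.step()

    model.save_models()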