import torch
import pandas as pd

from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

import re
from time import time
import pickle

import wandb

from dawgz import job, schedule
import os
import json

def tokenize(review):
    review = review.lower()
    review = re.sub(r'[^a-zA-Z0-9_ ]', '', review)
    review = review.split(' ')
    review = list(filter(lambda x: x != '', review))
    return review

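# Example: tokenize("Great game, 10/10!") -> ['great', 'game', '1010']
# (punctuation is stripped before splitting, so "10/10" collapses to "1010").
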
def load_voc():
    global i2t
    global t2i
    global eos_token_id
    global pad_token_id
    global unk_token_id

    with open('voc', 'rb') as fp:
        i2t = pickle.load(fp)
    t2i = {i2t[i]: i for i in range(len(i2t))}
    eos_token_id = t2i['<eos>']
    pad_token_id = t2i['<pad>']
    unk_token_id = t2i['<unk>']

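# The 'voc' file is a pickled list mapping token id -> token string, and it must
# contain the '<eos>', '<pad>' and '<unk>' specials. A minimal sketch for testing
# locally (the word list here is hypothetical; a real vocabulary would be built
# from the training corpus):
#
#     words = ['<pad>', '<eos>', '<unk>', 'game', 'fun', 'boring']
#     with open('voc', 'wb') as fp:
#         pickle.dump(words, fp)
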
class SteamReviewsDataset(Dataset):
    def __init__(self, reviews, scores):
        self.reviews = reviews
        self.scores = scores

    def __len__(self):
        return len(self.reviews)

    def __getitem__(self, idx):
        return self.reviews[idx], self.scores[idx]

def collate_fn(data):
    reviews, scores = zip(*data)
    max_len = max(len(d[0]) for d in data) + 1

    def eos_pad(review):
        len_r = len(review)
        review = torch.cat([review, torch.tensor([eos_token_id])])
        if len_r == max_len - 1:
            return review
        pad = torch.tensor([pad_token_id] * (max_len - len_r - 1))
        return torch.cat([pad, review])

    reviews = list(map(eos_pad, reviews))

    return torch.stack(reviews), torch.stack(scores)

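# Each review gets '<eos>' appended and is then left-padded with '<pad>', so the
# last LSTM time step always sees the end of the review (this is what makes
# mode='last' in the classifier meaningful). For lengths 3 and 1 (max_len = 4):
#     [a, b, c]  ->  [a, b, c, <eos>]
#     [d]        ->  [<pad>, <pad>, d, <eos>]
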
def load_from_csv(N: int, file_name: str):
    data: pd.DataFrame = pd.read_csv(f'{file_name}.csv', nrows=N)

    X: pd.Series = data['review_text'].astype(str)
    X = X.apply(tokenize)
    X = X.apply(
        lambda review: torch.tensor([t2i.get(word, unk_token_id) for word in review])
    ).tolist()

    y = torch.tensor(data['review_score'].to_numpy()).to(torch.float32)

    del data

    return X, y

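# Expected CSV schema (one file per sentiment class). Labels are assumed to be
# 0/1 here, matching the BCE loss and the 0./1. masks in the accuracy helpers:
#
#     review_text,review_score
#     "best game ever",1
#     "refunded after an hour",0
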
def make_data(N_total: int, N_train: int, batch_size: int, t2i: dict, unk_token_id: int, device: str = 'cpu'):
    Xp, yp = load_from_csv(N_total // 2, 'data_positive_50to500')
    Xn, yn = load_from_csv(N_total // 2, 'data_negative_50to500')

    X = Xp + Xn
    y = torch.cat([yp, yn])

    # Shuffle positives and negatives together before splitting.
    perm = torch.randperm(N_total)
    X, y = [X[i] for i in perm], y[perm]

    X_train, y_train = X[:N_train], y[:N_train]
    X_test, y_test = X[N_train:], y[N_train:]

    dataset_train = SteamReviewsDataset(X_train, y_train)
    dataset_test = SteamReviewsDataset(X_test, y_test)

    train_dl = DataLoader(dataset_train, collate_fn=collate_fn, batch_size=batch_size,
                          shuffle=True, num_workers=0, generator=torch.Generator(y.device))
    test_dl = DataLoader(dataset_test, collate_fn=collate_fn, batch_size=batch_size,
                         shuffle=True, num_workers=0, generator=torch.Generator(y.device))

    return train_dl, test_dl

class SentimentClassifier(nn.Module):
    def __init__(self,
                 emb_size=128,
                 hidden_size=256,
                 bidirectional=False,
                 mode='last',
                 voc_size=10_000):
        """
        mode: 'last', 'mean'
        """
        super().__init__()

        self.mode = mode

        self.embedding = nn.Parameter(torch.randn(voc_size, emb_size))
        self.lstm = nn.LSTM(
            emb_size,
            hidden_size,
            bidirectional=bidirectional,
            batch_first=True,
        )
        self.dropout = nn.Dropout(0.15)
        D = 2 if bidirectional else 1
        out_size = hidden_size * D
        self.out_mlp = nn.Sequential(
            nn.Linear(out_size, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        emb = self.embedding[x]
        output, _ = self.lstm(emb)
        if self.mode == 'mean':
            output = output.mean(dim=1)
        elif self.mode == 'last':
            output = output[:, -1, :]
        output = self.dropout(output)
        output = self.out_mlp(output)
        # Squeeze only the feature dimension so a batch of size 1 keeps its batch axis.
        return output.squeeze(-1)

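# A minimal forward-pass sanity check (hypothetical values, not from the run):
#
#     sc = SentimentClassifier(emb_size=32, hidden_size=64, voc_size=100)
#     x = torch.randint(0, 100, (8, 20))  # batch of 8 reviews, 20 tokens each
#     assert sc(x).shape == (8,)          # one probability in (0, 1) per review
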
# 2 modes x 5 embedding sizes x 5 hidden sizes x 5 learning rates = 250 configurations.
CONFIGS = [
    {
        'name': f'{mode} emb={e} h={h} lr={lr}',
        'emb_size': e,
        'hidden_size': h,
        'lr': lr,
        'mode': mode,
    }
    for mode in ['last', 'mean']
    for e in [32, 64, 128, 256, 512]
    for h in [32, 64, 128, 256, 512]
    for lr in [1e-6, 1e-5, 1e-4, 1e-3, 1e-2]
]

@job(
    # array=1,
    array=len(CONFIGS),
    partition="a5000,tesla,quadro,2080ti",
    cpus=4,
    gpus=1,
    ram="16GB",
    time="24:00:00",
    name="fctlstm",
)
def train(i: int):
    torch.set_default_device('cpu')

    run_config = CONFIGS[i]

    load_voc()

    # Build the classifier for this configuration.
    sc = SentimentClassifier(
        emb_size=run_config['emb_size'],
        hidden_size=run_config['hidden_size'],
        voc_size=len(i2t),
        mode=run_config['mode'],
    )

    nb_params = sum(p.numel() for p in sc.parameters())

    lr = run_config['lr']
    batch_size = 64
    opt = torch.optim.Adam(sc.parameters(), lr=lr)
    loss = nn.BCELoss()

    def acc(preds, gt):
        preds, gt = preds.squeeze(), gt.squeeze()
        return (torch.sum(torch.round(preds) == gt) / len(preds) * 100).item()

    def acc_neg(preds, gt):
        preds, gt = preds.squeeze(), gt.squeeze()
        preds = preds[gt == 0.]
        gt = gt[gt == 0.]
        if len(preds) == 0:
            return None  # no negative example in this batch
        return (torch.sum(torch.round(preds) == gt) / len(preds) * 100).item()

    def acc_pos(preds, gt):
        preds, gt = preds.squeeze(), gt.squeeze()
        preds = preds[gt == 1.]
        gt = gt[gt == 1.]
        if len(preds) == 0:
            return None  # no positive example in this batch
        return (torch.sum(torch.round(preds) == gt) / len(preds) * 100).item()

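    # Example: acc(torch.tensor([0.9, 0.2]), torch.tensor([1., 1.])) == 50.0
    # (predictions are rounded to 0/1 before comparing against the labels).
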
    wandb_enabled = True
    if wandb_enabled:
        wandb.init(
            project='sentiment',
            name=run_config['name'],
            config={
                'lr': lr,
                'batch_size': batch_size,
                'nb_params': nb_params,
                'mode': sc.mode,
                'experiment-name': 'alan',
                'emb_size': sc.embedding.shape[1],
                'hidden_size': sc.lstm.hidden_size,
                'bidirectional': sc.lstm.bidirectional,
                'dropout': sc.dropout.p,
            },
            reinit=True,
        )
        wandb.define_metric("iter")
        wandb.define_metric("train/*", step_metric="iter")
        wandb.define_metric("testiter")
        wandb.define_metric("test/*", step_metric="testiter")

    print(f"Starting training with device {sc.embedding.device}.")
    print(f"Trainable parameters: {nb_params}.")

    train_data, test_data = make_data(N_total=200_000, N_train=195_000, batch_size=batch_size,
                                      t2i=t2i, unk_token_id=unk_token_id)

    step = 0
    test_interval = 1000

    for epoch in range(n_epochs := 10):
        for reviews, scores in train_data:
            start = time()
            preds = sc(reviews)
            batch_loss = loss(preds, scores)
            opt.zero_grad()
            batch_loss.backward()
            opt.step()

            log_dict = {}

            if wandb_enabled:
                log_dict.update({
                    'train/loss': batch_loss.item(),
                    'train/seq_size': reviews.shape[1],
                    'train/time': time() - start,
                    'iter': step,
                })

            if step % test_interval == 0:
                with torch.no_grad():
                    sc.eval()
                    test_losses = torch.zeros(len(test_data))
                    test_accuracies = torch.zeros(len(test_data))
                    test_accuracies_pos = []
                    test_accuracies_neg = []
                    for bi, (reviews, scores) in enumerate(test_data):
                        preds = sc(reviews)
                        test_losses[bi] = loss(preds, scores)
                        test_accuracies[bi] = acc(preds, scores)
                        acc_negative = acc_neg(preds, scores)
                        acc_positive = acc_pos(preds, scores)
                        if acc_negative is not None:
                            test_accuracies_neg.append(acc_negative)
                        if acc_positive is not None:
                            test_accuracies_pos.append(acc_positive)

                    log_dict.update({
                        'test/loss': test_losses.mean().item(),
                        'test/acc': test_accuracies.mean().item(),
                        'test/acc_positive': torch.tensor(test_accuracies_pos).mean().item(),
                        'test/acc_negative': torch.tensor(test_accuracies_neg).mean().item(),
                        'testiter': step,
                    })

                    # Qualitative examples: decode a few test reviews back to text.
                    X_t, y_t = next(iter(test_data))
                    X_t, y_t = X_t[:8], y_t[:8]
                    y_hat = sc(X_t)
                    # Convert X_t back to text.
                    X_t_text = []
                    for x in X_t:
                        x_text = []
                        for tok in x:
                            if tok.item() == pad_token_id:
                                continue
                            if tok.item() == unk_token_id:
                                x_text.append("<??>")
                                continue
                            x_text.append(i2t[tok.item()])
                        X_t_text.append(" ".join(x_text))
                    y_t = y_t.tolist()
                    y_hat = y_hat.tolist()
                    test_samples = [[X_t_text[k], y_hat[k], y_t[k]] for k in range(len(X_t_text))]
                    log_dict.update({
                        "test/samples": wandb.Table(
                            data=test_samples,
                            columns=["Review", "Predicted Label", "True Label"],
                        )})

                    sc.train()

            if wandb_enabled:
                wandb.log(log_dict)

            step += 1

    if wandb_enabled:
        wandb.finish()

    # Save the config and trained weights for this run.
    os.makedirs(f'models/{run_config["name"]}', exist_ok=True)
    with open(f'models/{run_config["name"]}/config.json', 'w') as fp:
        json.dump(run_config, fp)
    torch.save(sc.state_dict(), f'models/{run_config["name"]}/model.pt')

if __name__ == '__main__':
    schedule(
        train,
        backend='slurm',
        export="ALL",
        shell="/bin/sh",
        env=["export WANDB_SILENT=true"],
    )
    print(f"Scheduled {len(CONFIGS)} jobs.")

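# For a local smoke test without Slurm, dawgz also provides an asynchronous
# backend; depending on the dawgz version, something like the following may work
# in place of the Slurm call above (an assumption, not verified here):
#
#     schedule(train, backend='async')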