Advertisement
alkkofficial

Untitled

Mar 31st, 2023
74
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.48 KB | None | 0 0
# 真零 eval engine with NN
# zan1ling4 | 真零 | (pronounced Jun Ling)
  3. import matplotlib.pyplot as plt
  4. import numpy as np
  5. import pandas as pd
  6. import torch
  7. import torch.nn as nn
  8. import torch.optim as optim
  9. import copy
  10. import tqdm
  11. from sklearn.model_selection import train_test_split
  12. import chess
  13. from sklearn.preprocessing import StandardScaler
  14.  
  15. # puzzle presets
  16. MAX_MOMENTS = 20
  17. TOTAL_GAMES = 4
  18. board = chess.Board()
  19. files = ["./a.txt", "./ab.txt", "./ac.txt"]
  20.  
  21. # pre-training loop
  22.  
  23. final_df = pd.DataFrame()
  24.  
  25. # set up training data
  26.  
  27. for file in files:
  28. with open(file, "r") as f:
  29. contents = f.read()
  30.  
  31. contents = contents.split(" \n")
  32. df_add = pd.DataFrame(columns=["moves", "status"])
  33.  
  34. for game in contents:
  35. if file == "./a.txt": # NOTE: change the filename
  36. d = {"moves": game, "status": "won"}
  37. df_add.loc[len(df_add)] = d
  38. elif file == "./ab.txt":
  39. d = {"moves": game, "status": "lost"}
  40. df_add.loc[len(df_add)] = d
  41. elif file == "./ac.txt":
  42. d = {"moves": game, "status": "drew"}
  43. df_add.loc[len(df_add)] = d
  44. final_df = pd.concat([final_df, df_add], ignore_index=True)
  45. # define function that takes chess board and turns into AI-understandable matrix
  46.  
  47.  
  48. def board_data(board):
  49. board_array = np.zeros((8, 8, 13), dtype=np.int8)
  50. for i in range(64):
  51. piece = board.piece_at(i)
  52. if piece is not None:
  53. color = int(piece.color)
  54. piece_type = piece.piece_type - 1
  55. board_array[i // 8][i % 8][piece_type + 6 * color] = 1
  56. else:
  57. board_array[i // 8][i % 8][-1] = 1
  58. board_array = board_array.flatten()
  59. return board_array
  60.  
  61.  
  62. game = 0
  63. train_df = pd.DataFrame(
  64. columns=[
  65. "board_data",
  66. "status",
  67. ]
  68. )
  69.  
  70. for index, row in final_df.iterrows():
  71. moves = row["moves"].split(" ")
  72. status = row["status"]
  73. moves = [x for x in moves if x] # removes empty strings in list
  74. if len(moves) <= MAX_MOMENTS:
  75. MAX_MOMENTS = len(moves)
  76. unsampled_idx = [x for x in range(len(moves))]
  77. unsampled_idx.pop(0)
  78. for _ in range(MAX_MOMENTS - 1):
  79. board = chess.Board()
  80. up_to = np.random.choice(unsampled_idx)
  81. unsampled_idx.remove(up_to)
  82. moment = moves[:up_to]
  83. df_add = pd.DataFrame(
  84. columns=[
  85. "board_data",
  86. "status",
  87. ]
  88. )
  89. for move in moment:
  90. board.push(chess.Move.from_uci(move))
  91. ai_moves = board_data(board)
  92.  
  93. counter = 0
  94. d = {
  95. "board_data": ai_moves,
  96. "status": status,
  97. }
  98. df_add.loc[len(df_add)] = d
  99. train_df = pd.concat([train_df, df_add], ignore_index=True)
  100. game += 1
  101.  
  102. # preprocessing data
  103.  
  104. train_df["status"] = train_df["status"].map(
  105. {"won": 1.0, "lost": -1.0, "drew": 0.0}
  106. ) # target
  107. X = np.array([x for x in train_df["board_data"]])
  108. temp = ""
  109. big_l = []
  110. for x in train_df["status"]:
  111. temp = (str(x) + "2") * 832
  112. temp = temp.split("2")
  113. temp = temp[:-1]
  114. temp = [float(y) for y in temp]
  115. big_l.append(temp)
  116. train_df["status"] = big_l
  117.  
  118.  
  119. y = np.array([x for x in train_df["status"]])
  120. # train-test split for model evaluation
  121. X_train_raw, X_test_raw, y_train, y_test = train_test_split(
  122. X, y, train_size=0.7, shuffle=True
  123. )
  124. # Standardizing data
  125. scaler = StandardScaler()
  126. scaler.fit(X_train_raw)
  127. X_train = scaler.transform(X_train_raw)
  128. X_test = scaler.transform(X_test_raw)
  129. # Convert to 2D PyTorch tensors
  130.  
  131. X_train = torch.tensor(X_train, dtype=torch.float32)
  132. y_train = torch.tensor(y_train, dtype=torch.float32)
  133. X_test = torch.tensor(X_test, dtype=torch.float32)
  134. y_test = torch.tensor(y_test, dtype=torch.float32)
  135. model = nn.Sequential(
  136. nn.Linear(832, 832),
  137. nn.Linear(832, 832),
  138. nn.Linear(832, 832),
  139. nn.Linear(832, 832),
  140. nn.Tanh(),
  141. )
  142. # loss function and optimizer
  143. loss_fn = nn.MSELoss() # mean square error
  144. optimizer = optim.AdamW(model.parameters(), lr=0.0001)
  145.  
  146. n_epochs = 50 # number of epochs to run
  147. batch_size = 3 # size of each batch
  148. batch_start = torch.arange(0, len(X_train), batch_size)
  149.  
  150. # Hold the best model
  151. best_mse = np.inf # init to infinity
  152. best_weights = None
  153. history = []
  154.  
  155. for epoch in tqdm.tqdm(range(n_epochs), desc="Epochs"):
  156. model.train()
  157. epoch_loss = 0.0
  158. for batch_idx in batch_start:
  159. batch_X, batch_y = X_train[batch_idx : batch_idx + batch_size], y_train[batch_idx : batch_idx + batch_size]
  160. optimizer.zero_grad()
  161. y_pred = model(batch_X)
  162. loss = loss_fn(y_pred, batch_y)
  163. loss.backward()
  164. optimizer.step()
  165. epoch_loss += loss.item() * batch_X.shape[0]
  166. epoch_loss /= len(X_train)
  167. history.append(epoch_loss)
  168. if epoch_loss < best_mse:
  169. best_mse = epoch_loss
  170. best_weights = copy.deepcopy(model.state_dict())
  171.  
  172. # load the best weights into the model
  173. model.load_state_dict(best_weights)
  174.  
  175. print("MSE: %.2f" % best_mse)
  176. print("RMSE: %.2f" % np.sqrt(best_mse))
  177. plt.plot(history)
  178. plt.draw()
  179. plt.savefig("ai-eval-losses.jpg")
  180. model.eval()
  181. with torch.no_grad():
  182. # Test out inference with 5 samples
  183. for i in range(5):
  184. X_sample = X_test_raw[i : i + 1]
  185. X_sample = scaler.transform(X_sample)
  186. X_sample = torch.tensor(X_sample, dtype=torch.float32)
  187. y_pred = model(X_sample)
  188. torch.save(best_weights, "ai_zan1ling4_eval.pt")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement