Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Sanity-check training loop: overfit a TwoLayerNet on a single mini-batch.
# You should expect loss to go down and train and val accuracy go up for every epoch.
# from neural_net import AltTwoLayerNet
tol = 1e-5

# --- Hyperparameters ---
input_size = train_x.shape[1]
hidden_size = 100
output_size = 10
reg = 0.1
std = 0.001  # weight-init multiplier (currently unused; TODO pass to the model if supported)

# BUG FIX: original passed `input_size=nis`, but `nis` was never defined
# (NameError). Use the `input_size` computed above.
model = TwoLayerNet(input_size=input_size, hidden_size=hidden_size,
                    output_size=output_size, reg=reg)
# model2 = AltTwoLayerNet(input_size=input_size, hidden_size=hidden_size,
#                         output_size=output_size, reg=reg)
dataset = Dataset(train_x, train_y, val_x, val_y)

# --- Build shuffled mini-batches and pick the first one ---
batch_size = 30
np.random.seed(69)  # fixed seed so runs are reproducible
num_train = dataset.train_x.shape[0]
shuffled_indices = np.arange(num_train)
np.random.shuffle(shuffled_indices)
sections = np.arange(batch_size, num_train, batch_size)
batches_indices = np.array_split(shuffled_indices, sections)
batch_indices = batches_indices[0]
batch_x = dataset.train_x[batch_indices]
batch_y = dataset.train_y[batch_indices]

# --- Plain SGD on the single batch ---
learning_rate = 0.05
for i in range(10000):
    loss, grads = model.compute_loss_and_gradients(batch_x, batch_y)
    params = model.params
    if i % 1000 == 0:
        print("loss = ", loss)
        print("values = ", {k: np.sum(v) for k, v in params.items()})
        print("grads = ", {k: np.sum(v) for k, v in grads.items()})
        print()
        # print(batch_y, " > ", model.predict(batch_x))
    for name, param in params.items():
        # BUG FIX: removed `optimizer = optimizers[name]` — `optimizers` was
        # never defined (its construction is commented out above), so this
        # line raised NameError on the very first iteration.
        #
        # BUG FIX: update in place instead of rebinding the dict entry.
        # `params[name] = params[name] - lr * grad` allocates a NEW array,
        # so any layer holding a reference to the old array never sees the
        # update — consistent with the pasted output, where the grads are
        # byte-identical on every epoch while the loss climbs.
        param -= learning_rate * grads[name]

# trainer = Trainer(model2, dataset, SGD(), num_epochs=1, batch_size=batch_size,
#                   learning_rate=0.01, learning_rate_decay=0.9)
# loss_history, train_history, val_history = trainer.fit()
# model1.predict(train_x[0])
# print("batches_indices[0]:", train_x[batches_indices[0]].shape)
# model1.predict(train_x[batches_indices[0]])
- ################
- #### output ####
- ################
- loss = 2.302738695380354
- values = {'w1': -0.004068116397824662, 'b1': 0.01, 'w2': -0.0020747624855106516, 'b2': 0.001}
- grads = {'w1': 0.019029607815834754, 'b1': -8.17012587906322e-05, 'w2': -0.0002074762485510651, 'b2': -1.3877787807814457e-17}
- loss = 2.305641633884892
- values = {'w1': -0.9555485071895646, 'b1': 0.014085062939531566, 'w2': 0.008299049942042487, 'b2': 0.000999999999726775}
- grads = {'w1': 0.019029607815834754, 'b1': -8.17012587906322e-05, 'w2': -0.0002074762485510651, 'b2': -1.3877787807814457e-17}
- loss = 2.3174225788290426
- values = {'w1': -1.9070288979812922, 'b1': 0.018170125879062932, 'w2': 0.018672862369594933, 'b2': 0.0009999999995944364}
- grads = {'w1': 0.019029607815834754, 'b1': -8.17012587906322e-05, 'w2': -0.0002074762485510651, 'b2': -1.3877787807814457e-17}
- loss = 2.3380815302128046
- values = {'w1': -2.8585092887729777, 'b1': 0.022255188818594434, 'w2': 0.029046674797153428, 'b2': 0.0009999999999514841}
- grads = {'w1': 0.019029607815834754, 'b1': -8.17012587906322e-05, 'w2': -0.0002074762485510651, 'b2': -1.3877787807814457e-17}
- loss = 2.367618488036179
- values = {'w1': -3.809989679564654, 'b1': 0.026340251758126456, 'w2': 0.03942048722472072, 'b2': 0.001000000001369017}
- grads = {'w1': 0.019029607815834754, 'b1': -8.17012587906322e-05, 'w2': -0.0002074762485510651, 'b2': -1.3877787807814457e-17}
- loss = 2.4060334522991647
- values = {'w1': -4.761470070356343, 'b1': 0.030425314697659713, 'w2': 0.04979429965227849, 'b2': 0.0010000000031453737}
- grads = {'w1': 0.019029607815834754, 'b1': -8.17012587906322e-05, 'w2': -0.0002074762485510651, 'b2': -1.3877787807814457e-17}
- loss = 2.4533264230017626
- values = {'w1': -5.712950461148048, 'b1': 0.034510377637193, 'w2': 0.06016811207983852, 'b2': 0.0010000000049217306}
- grads = {'w1': 0.019029607815834754, 'b1': -8.17012587906322e-05, 'w2': -0.0002074762485510651, 'b2': -1.3877787807814457e-17}
- loss = 2.5094974001439723
- values = {'w1': -6.664430851939748, 'b1': 0.03859544057672655, 'w2': 0.07054192450739312, 'b2': 0.0009999999981715746}
- grads = {'w1': 0.019029607815834754, 'b1': -8.17012587906322e-05, 'w2': -0.0002074762485510651, 'b2': -1.3877787807814457e-17}
- loss = 2.574546383725794
- values = {'w1': -7.61591124273147, 'b1': 0.042680503516257406, 'w2': 0.08091573693496038, 'b2': 0.0009999999857370767}
- grads = {'w1': 0.019029607815834754, 'b1': -8.17012587906322e-05, 'w2': -0.0002074762485510651, 'b2': -1.3877787807814457e-17}
- loss = 2.648473373747228
- values = {'w1': -8.567391633523291, 'b1': 0.04676556645578911, 'w2': 0.09128954936252516, 'b2': 0.0009999999733025788}
- grads = {'w1': 0.019029607815834754, 'b1': -8.17012587906322e-05, 'w2': -0.0002074762485510651, 'b2': -1.3877787807814457e-17}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement