Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # hopfield_solution.py
- import pdb
- import math
- import random
- import os, sys
- import ast
def print_each_at_x(vars_list, nnn=22):
    """Print the given values on one line in fixed-width columns.

    Each value is converted to text, padded or truncated to its column
    width, and the columns are concatenated. Trailing whitespace is
    stripped before the line is printed.

    Args:
        vars_list: Iterable of values to print. ``list`` items are rendered
            as their elements separated by single spaces; every other value
            is rendered with ``str``.
        nnn: Base column width in characters. A value whose text is longer
            than ``nnn - 4`` characters gets a double-width column instead.
    """
    cells = []
    for value in vars_list:
        if isinstance(value, list):
            text = ' '.join(str(item) for item in value)
        else:
            text = str(value)
        # Long values get a double-width column so less of them is cut off.
        width = nnn * 2 if len(text) > nnn - 4 else nnn
        # Pad/truncate to width - 1 characters, then add one separator space.
        cells.append((text + ' ' * width)[:width - 1] + ' ')
    print(''.join(cells).rstrip())
# Hyperparameters
LEARNING_RATE = 0.0001  # scale factor applied to every weight delta in direct_update()
EPOCHS = 200            # number of passes train() makes over the data set
THRESHOLD = 0.8         # minimum share (0..1) the top digit needs in guess(); below it the guess is 'X'
def direct_update(weights, node, pi_digit, ppp, learning_rate=None):
    """Nudge ``weights[node, ppp]`` by a learning-rate-scaled delta.

    The delta is ``(1 - w_target * (w_pred - hit)) * learning_rate`` where
    ``w_target`` is ``weights[node, pi_digit]``, ``w_pred`` is
    ``weights[node, ppp]`` and ``hit`` is 1 when ``pi_digit == ppp``,
    otherwise 0.

    Args:
        weights: Mapping keyed by ``(node, label)`` tuples; mutated in place.
        node: Node whose weight is adjusted.
        pi_digit: Label of the target (correct) digit.
        ppp: Label whose weight receives the delta.
        learning_rate: Step size. Defaults to the module-level
            ``LEARNING_RATE`` when omitted.

    Returns:
        The same ``weights`` mapping that was passed in.

    Raises:
        KeyError: If ``(node, pi_digit)`` or ``(node, ppp)`` is not a key of
            ``weights``.
    """
    if learning_rate is None:
        learning_rate = LEARNING_RATE
    hit = pi_digit == ppp  # bool, counts as 0/1 in the arithmetic below
    delta = (1.0 - weights[node, pi_digit] * (weights[node, ppp] - hit)) * learning_rate
    weights[node, ppp] += delta
    return weights
def update_weight(weights, nodes, pi_digit, data, ppp, recall_pattern):
    """Record ``ppp`` in the recall window and update the heaviest matching node.

    ``ppp`` is appended to ``recall_pattern``. If ``ppp`` differs from
    ``pi_digit`` and the window holds at least two entries, the node with the
    largest ``weights[node, pi_digit]`` is selected among nodes whose first
    element is in the window and whose second element is not. That node is
    passed to ``direct_update``, and the oldest window entry is dropped.

    Args:
        weights: Mapping keyed by ``(node, label)`` tuples.
        nodes: Iterable of two-element nodes.
        pi_digit: Label of the target digit.
        data: Unused; kept so existing callers keep working.
        ppp: Value to record in the recall window.
        recall_pattern: List used as the recall window; mutated in place.

    Returns:
        A ``(weights, recall_pattern)`` tuple.
    """
    recall_pattern.append(ppp)
    if ppp == pi_digit or len(recall_pattern) < 2:
        return weights, recall_pattern

    heaviest_node = None
    max_weight = -1  # nodes whose weight is <= -1 are never selected
    for node in nodes:
        if node[0] in recall_pattern and node[1] not in recall_pattern:
            candidate = weights[node, pi_digit]
            if candidate > max_weight:
                max_weight = candidate
                heaviest_node = node
    if heaviest_node is not None:
        weights = direct_update(weights, heaviest_node, pi_digit, ppp)
    # Slide the window forward by discarding its oldest entry.
    recall_pattern.pop(0)
    return weights, recall_pattern
def initialize_dict(digits, nodes):
    """Create the initial weight and target tables.

    Args:
        digits: Iterable of digit labels.
        nodes: Iterable of nodes.

    Returns:
        A ``(weights, target)`` tuple of dicts keyed by ``(node, digit)``.
        Keys are ordered digit by digit, then node by node. Every weight
        starts at ``1`` and every target at ``0``.
    """
    # Materialise the keys once so both tables share the same key order.
    keys = [(node, digit) for digit in digits for node in nodes]
    weights = dict.fromkeys(keys, 1)
    target = dict.fromkeys(keys, 0)
    return weights, target
def guess(data, weights, nodes, target, pi_digit=None, threshold=None):
    """Predict a digit label from the weights and compare it with the true label.

    For every node, ``weights[node, pi_digit]`` is added to the score of the
    digit named by ``target[node, pi_digit]``. Scores are converted to
    percentages of the total, and the digit with the highest percentage is
    the guess. If that percentage is below ``threshold * 100`` the guess
    becomes ``'X'``.

    Args:
        data: Sequence of ``(digit, pixels)`` samples. Only the label of the
            first sample is read, and only when ``pi_digit`` is not given.
        weights: Mapping keyed by ``(node, label)`` tuples.
        nodes: Iterable of nodes to score.
        target: Mapping keyed by ``(node, label)`` whose values are digits
            0-9.
        pi_digit: Label of the true digit. Defaults to the label of
            ``data[0]``.
        threshold: Minimum share (0..1) the winning digit must reach.
            Defaults to the module-level ``THRESHOLD``.

    Returns:
        A ``(prediction, ppp)`` tuple. ``ppp`` is the guessed label, or
        ``'X'`` when no digit reaches the threshold. ``prediction`` is
        ``True`` when ``ppp`` equals the true label.

    Raises:
        ValueError: If ``pi_digit`` is omitted and ``data`` is empty.
        KeyError: If a ``(node, pi_digit)`` key is missing from ``weights``
            or ``target``, or a target value is not a digit 0-9.
    """
    if pi_digit is None:
        if not data:
            raise ValueError("guess() needs pi_digit or a non-empty data list")
        pi_digit = data[0][0]
    pi_digit = str(pi_digit)
    if threshold is None:
        threshold = THRESHOLD

    scores = {str(i): 0 for i in range(10)}
    for node in nodes:
        scores[str(target[node, pi_digit])] += weights[node, pi_digit]

    # Computed once; a zero total yields all-zero percentages instead of a
    # ZeroDivisionError.
    total = sum(scores.values())
    scale = 100.0 / total if total else 0.0
    percentages = {digit: round(score * scale, 6) for digit, score in scores.items()}

    ppp = max(percentages, key=percentages.get)
    if percentages[ppp] < threshold * 100:
        ppp = 'X'
    return ppp == pi_digit, ppp
def train(weights, target, nodes, recall_pattern, samples=None, epochs=None):
    """Run the training loop and print progress every tenth epoch.

    Each epoch shuffles the samples in place. Every pixel of a sample is fed
    to ``update_weight`` as ``'1'`` (truthy pixel) or ``'-1'`` (falsy pixel),
    then ``guess`` is called for the sample and the running right/wrong
    counters are updated.

    Args:
        weights: Mapping keyed by ``(node, label)`` tuples; updated through
            ``update_weight``.
        target: Mapping keyed by ``(node, label)`` passed on to ``guess``.
        nodes: Iterable of nodes.
        recall_pattern: Accepted for interface compatibility; a fresh recall
            window is started for every sample.
        samples: List of ``(digit, pixels)`` pairs. Defaults to the
            module-level ``data``.
        epochs: Number of passes over the samples. Defaults to the
            module-level ``EPOCHS``.

    Returns:
        None.
    """
    dataset = data if samples is None else samples
    total_epochs = EPOCHS if epochs is None else epochs
    right = 0
    wrong = 0
    last_guess = None
    for epoch in range(total_epochs):
        random.shuffle(dataset)
        for digit, pixels in dataset:
            pi_digit = str(digit)
            recall_pattern = []
            for pixel in pixels:
                ppp = '1' if pixel else '-1'
                weights, recall_pattern = update_weight(
                    weights, nodes, pi_digit, dataset, ppp, recall_pattern
                )
            # Hand guess() only the current sample so it is scored against
            # this sample's label.
            prediction, last_guess = guess([(digit, pixels)], weights, nodes, target)
            if prediction:
                right += 1
            else:
                wrong += 1
        if epoch % 10 == 0:
            seen = right + wrong
            # Guard against an empty data set, where no guess was made yet.
            accuracy = round(100.0 * right / seen, 2) if seen else 0.0
            # The per-digit percentages are local to guess() and are not
            # returned, so they cannot be reported here.
            print_each_at_x(["epoch", epoch, "accuracy", accuracy, "X=", last_guess])
            sys.stdout.flush()
# Build the digit labels and the (pixel index, digit) node grid for 28x28
# inputs, then create the initial weight and target tables.
digits = [str(i) for i in range(10)]
nodes = [(i, j) for i in range(28*28) for j in digits]
weights, target = initialize_dict(digits, nodes)

# Train the network
# NOTE(review): train() reads its samples from a module-level `data` list of
# (digit, pixels) pairs, which the visible source never defines. It must be
# assigned before this call; confirm where it is expected to come from.
train(weights, target, nodes, [])

# Save the trained weights to a file
with open('weights.txt', 'w') as f:
    f.write(str(weights))

# Load the trained weights from a file
# literal_eval only parses Python literals, so the file cannot run code.
with open('weights.txt', 'r') as f:
    weights = ast.literal_eval(f.read())

# Example usage:
# test_data = [(8, [0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1])]
# prediction, ppp = guess(test_data, weights, nodes, target)
# print(prediction, ppp)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement