Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# swap_closest_weight.py
def swap_closest_weight(weights, incorrect_pred, threshold=5):
    """
    Move one unit of weight from ``incorrect_pred`` to its closest neighbour.

    Among all keys of ``weights`` other than ``incorrect_pred`` whose value is
    strictly greater than ``threshold``, pick the one whose value is closest
    to ``weights[incorrect_pred]``. If several are equally close, the first in
    iteration order wins. Then decrement ``weights[incorrect_pred]`` by 1 and
    increment the chosen key by 1.

    Args:
        weights: Mutable mapping of key -> numeric weight; updated in place.
        incorrect_pred: Key of ``weights`` whose weight is reduced.
        threshold: Only keys whose weight is strictly greater than this value
            are eligible to receive the unit. Defaults to 5.

    Returns:
        The key that received the unit, or ``None`` when no key qualified.
        In the ``None`` case ``weights`` is left unchanged.

    Raises:
        KeyError: If ``incorrect_pred`` is not a key of ``weights``.
    """
    current_weight = weights[incorrect_pred]
    closest_key = None
    best_delta = float('inf')
    for key, value in weights.items():
        if key != incorrect_pred and value > threshold:
            delta = abs(value - current_weight)
            # Strict '<' keeps the first key found when distances tie.
            if delta < best_delta:
                best_delta = delta
                closest_key = key
    if closest_key is None:
        # No eligible neighbour: do not touch `weights` at all, so a failed
        # search never leaves the mapping half-updated.
        return None
    weights[incorrect_pred] -= 1
    weights[closest_key] += 1
    return closest_key
def guess(data, weights):
    """
    Return ``(correct, prediction)`` for ``data`` using the current weights.

    For each digit string ``'0'``..``'9'``, sum ``weights[node, digit]`` over
    every ``node`` in the module-level ``nodes`` collection (defined outside
    this function) for which ``digit in node`` holds. The digit with the
    largest total is the prediction, and ties go to the lowest digit. The
    prediction is compared with ``data[-1]``. On a miss,
    ``swap_closest_weight`` is called to adjust ``weights`` in place.

    Args:
        data: Indexable sample whose last element is the expected label; it is
            compared against a digit string.
        weights: Mapping indexed with ``(node, digit_str)`` tuples.

    Returns:
        A ``(correct, prediction)`` tuple. If the ten totals sum to zero, no
        prediction is made and ``(False, 'X')`` is returned.
    """
    # NOTE(review): pattern_recognition in this file builds weights with
    # integer digits, ``(node, i)``, while this lookup uses string digits,
    # ``(node, str(i))``. Confirm which key type is intended.
    totals = {}
    for i in range(10):
        digit = str(i)
        totals[digit] = sum(weights[node, digit] for node in nodes if digit in node)
    # Totals that cancel to zero carry no signal: report "no prediction".
    if sum(totals.values()) == 0:
        return False, 'X'
    # Rank the raw totals. Scaling them by 100 / sum would invert the order
    # whenever the sum is negative. max() returns the first maximal key, so
    # ties resolve to the lowest digit.
    prediction = max(totals, key=totals.get)
    correct = data[-1] == prediction
    if not correct:
        # NOTE(review): this passes a weight *value* where swap_closest_weight
        # expects a *key* of `weights`. With the (node, digit) tuple keys used
        # above, weights[data[-1]] itself looks like it would raise KeyError.
        # Left as-is pending confirmation of the intended key.
        swap_closest_weight(weights, weights[data[-1]])
    return correct, prediction
def pattern_recognition(pi):
    """
    Run one shuffled training pass of ``guess`` over the digits of ``pi``.

    ``weights`` starts at 1 for every ``(node, i)`` pair. Here ``node`` comes
    from the module-level ``nodes`` collection (defined outside this function)
    and ``i`` is an integer 0-9.

    Each element of ``pi`` is paired with its zero-padded position string
    (``'0000'``, ``'0001'``, ...). The pairs are shuffled, and each position
    string is passed to ``guess``.

    After a wrong guess the weights are adjusted. From the second consecutive
    miss onward, ``swap_closest_weight`` is called with probability 0.9.
    Otherwise ``learning_rate`` is moved from the sample's digit to the
    predicted label.

    Args:
        pi: Sequence of digits (e.g. a string) to train on.

    Returns:
        The ``weights`` dict after the pass.
    """
    import random  # no module-level import of `random` is visible here

    weights = {(node, i): 1 for node in nodes for i in range(10)}
    samples = [(digit, str(pos).zfill(4)) for pos, digit in enumerate(pi)]
    random.shuffle(samples)
    incorrect_count = 0  # length of the current run of wrong guesses
    skip_swap_prob = 0.1  # chance of falling back to the plain update
    learning_rate = 0.1
    for digit, data in samples:
        prediction_correct, prediction = guess(data, weights)
        if prediction_correct:
            incorrect_count = 0
            continue
        incorrect_count += 1
        if incorrect_count > 1 and random.random() > skip_swap_prob:
            # NOTE(review): passes a weight value, not a key; see the
            # matching note below on the key type of `weights`.
            swap_closest_weight(weights, weights[digit])
        else:
            # NOTE(review): `weights` is keyed by (node, i) tuples, so
            # indexing it with a bare digit or label looks like it would
            # raise KeyError. Confirm the intended keys.
            weights[digit] -= learning_rate
            weights[prediction] += learning_rate
    return weights
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement