Advertisement
mirosh111000

Дерево рішень

Oct 10th, 2023
56
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.19 KB | None | 0 0
  1. import numpy as np
  2. import pandas as pd
  3. from sklearn.model_selection import train_test_split
  4. from sklearn.tree import DecisionTreeClassifier
  5.  
  6. def gini_impurity(dataframe):
  7.     class_values = dataframe.iloc[:, -1].unique()
  8.     total_count = len(dataframe)
  9.     gini = 0.0
  10.     for class_value in class_values:
  11.         class_count = len(dataframe[dataframe.iloc[:, -1] == class_value])
  12.         class_probability = class_count / total_count
  13.         gini += (class_probability * (1 - class_probability))
  14.     return gini
  15.  
  16. def find_first_split(sorted_dataframe, index_col):
  17.     for i in range(1, len(sorted_dataframe)):
  18.         if sorted_dataframe.iloc[i, -1] != sorted_dataframe.iloc[i - 1, -1]:
  19.             first_split_value = sorted_dataframe.iloc[i, index_col]
  20.             return first_split_value
  21.  
  22. def information_gain(dataframe, feature_name, first_split_value):
  23.     total_gini_impurity = gini_impurity(dataframe)
  24.     left_subset = dataframe[dataframe[feature_name] <= first_split_value]
  25.     right_subset = dataframe[dataframe[feature_name] > first_split_value]
  26.     left_gini_impurity = gini_impurity(left_subset)
  27.     right_gini_impurity = gini_impurity(right_subset)
  28.     information_gain_value = total_gini_impurity - (len(left_subset) / len(dataframe)) * left_gini_impurity - (
  29.                 len(right_subset) / len(dataframe)) * right_gini_impurity
  30.     return information_gain_value
  31.  
  32. class Node:
  33.     def __init__(self, data, depth, col, val, class_):
  34.         self.data = data
  35.         self.depth = depth
  36.         self.col = col
  37.         self.val = val
  38.         self.class_ = class_
  39.         self.left = None
  40.         self.right = None
  41.  
  42.     @classmethod
  43.     def build_binary_tree(cls, dataframe, min_samples_split, max_depth, current_depth=0, current_col=np.NAN,
  44.                           current_val=np.NAN, current_class = np.NAN):
  45.         if len(dataframe) <= min_samples_split or current_depth >= max_depth:
  46.             if len(dataframe) > 0:
  47.                 current_class = dataframe.iloc[:, -1].value_counts().idxmax()
  48.             return cls(dataframe, current_depth, current_col, current_val, current_class)
  49.  
  50.         gain_df = pd.DataFrame({'Gain_value': [], 'Column_name': [], 'Split_value': []})
  51.         for i in range(len(dataframe.columns) - 1):
  52.             sorted_gini_df = dataframe.sort_values(by=dataframe.columns[i])
  53.             first_split_value = find_first_split(sorted_gini_df, i)
  54.             feature_name = sorted_gini_df.columns[i]
  55.             information_gain_value = information_gain(sorted_gini_df, feature_name, first_split_value)
  56.             gain_df.loc[i] = [information_gain_value, feature_name, first_split_value]
  57.  
  58.         max_gain_row_df = gain_df.loc[gain_df['Gain_value'].idxmax()]
  59.         most_common_class = dataframe.iloc[:, -1].value_counts().idxmax()
  60.  
  61.         left_subset = dataframe[dataframe[max_gain_row_df['Column_name']] <= max_gain_row_df['Split_value']]
  62.         right_subset = dataframe[dataframe[max_gain_row_df['Column_name']] > max_gain_row_df['Split_value']]
  63.         node = cls(dataframe, current_depth, max_gain_row_df['Column_name'], max_gain_row_df['Split_value'], most_common_class)
  64.         node.left = cls.build_binary_tree(left_subset, min_samples_split, max_depth, current_depth + 1)
  65.         node.right = cls.build_binary_tree(right_subset, min_samples_split, max_depth, current_depth + 1)
  66.         return node
  67.  
  68.     def train(self, min_samples_split, max_depth):
  69.  
  70.  
  71.         if (self is not None) and (not self.data.empty):
  72.             if len(self.data) <= min_samples_split or self.depth >= max_depth:
  73.                 self.data.iloc[:, -1] = self.class_
  74.  
  75.             final_df.loc[final_df.index.intersection(self.data.index)] = self.data.loc[
  76.                 final_df.index.intersection(self.data.index)]
  77.  
  78.             if (self.left is not None) and (not self.left.data.empty):
  79.                 self.left.train(min_samples_split, max_depth)
  80.             if (self.right is not None) and (not self.right.data.empty):
  81.                 self.right.train(min_samples_split, max_depth)
  82.         return self
  83.  
  84.     def predict(self, test_element):
  85.         if (self.left is None and self.right is None) or (self.val is None):
  86.             return self.data.iloc[0, -1]
  87.         if test_element[self.col] <= self.val:
  88.             return self.left.predict(test_element)
  89.         else:
  90.             return self.right.predict(test_element)
  91.  
  92.     def print_binary_tree(self, indent=""):
  93.         if (self is not None) and (not self.data.empty):
  94.  
  95.             print(indent + "Depth", self.depth)
  96.             print(indent + "Data:", len(self.data), "samples")
  97.             print(indent + "Column:", self.col,  ";  Split_Value: <=", self.val, '  (Left=True  |  Right=False)')
  98.             print(indent + "Predicted Class:", self.class_)
  99.  
  100.             if (self.left is not None) and (not self.left.data.empty):
  101.                 print(indent + "  Left:")
  102.                 self.left.print_binary_tree(indent + "    ")
  103.  
  104.             if (self.right is not None) and (not self.right.data.empty):
  105.                 print(indent + "  Right:")
  106.                 self.right.print_binary_tree(indent + "    ")
  107.  
  108. df = pd.read_csv('iris.csv')
  109. X = df.drop(labels=df.columns[-1], axis=1)
  110. Y = df[df.columns[-1]]
  111. x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.30, random_state=10)
  112. gini_df = pd.concat([x_train, y_train], axis=1)
  113. final_df = gini_df.copy()
  114.  
  115. min_samples_split = 5
  116. max_depth = 5
  117.  
  118. root_node = Node.build_binary_tree(gini_df, min_samples_split, max_depth)
  119. y = root_node.train(min_samples_split, max_depth)
  120.  
  121. root_node.print_binary_tree()
  122. gini_df['species_test'] = final_df['species']
  123. gini_df['Matching'] = gini_df.species == gini_df.species_test
  124. acc = round((gini_df['Matching'] == True).sum() / (len(gini_df['Matching'])) * 100, 2)
  125. gini_df.loc['Accuracy, %'] = ['' for i in range(len(gini_df.columns))]
  126. gini_df.iloc[-1, -1] = acc
  127. print(gini_df)
  128.  
  129. test_df = pd.concat([x_test, y_test], axis=1)
  130. species_test = test_df['species'].copy()
  131. for i in range(len(x_test)):
  132.     species_test.iloc[i] = root_node.predict(x_test.iloc[i])
  133.  
  134. test_df['species_test'] = species_test
  135. test_df['Matching'] = test_df.species == species_test
  136. acc1 = round((test_df['Matching'] == True).sum() / (len(test_df['Matching'])) * 100, 2)
  137. test_df.loc['Accuracy, %'] = ['' for i in range(len(test_df.columns))]
  138. test_df.iloc[-1, -1] = acc1
  139.  
  140. clf = DecisionTreeClassifier()
  141. clf.fit(x_train, y_train)
  142. accuracy_train = round(clf.score(x_train, y_train) * 100, 2)
  143. accuracy_test = round(clf.score(x_test, y_test) * 100, 2)
  144. predict_y = clf.predict(x_test)
  145. sk_match_y = predict_y == y_test
  146. predict_y = np.append(predict_y, '')
  147. sk_match_y = np.append(sk_match_y, accuracy_test)
  148. test_df['sklearn_species'] = predict_y
  149. test_df['sklearn_matching'] = sk_match_y
  150. test_df['sklearn_matching'] = test_df['sklearn_matching'].replace(0.0, False)
  151. test_df['sklearn_matching'] = test_df['sklearn_matching'].replace(1.0, True)
  152.  
  153. info_df = pd.DataFrame({'Decision Tree': [acc, acc1], 'sklearn Decision Tree': [accuracy_train, accuracy_test]},
  154.                        index=['Точність моделі на навчальних даних, %', 'Точність моделі на тестових даних, %'])
  155.  
  156. print(test_df)
  157. print(info_df)
  158.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement