# visualize decision tree
# jules0707, Dec 15th, 2024 (edited)

# %% [markdown]
# ## Title:
# Exercise: Visualizing a Decision Tree
#
# ## Description:
# The aim of this exercise is to visualize the decision tree that is built when performing Decision Tree classification or regression. The tree will look similar to the one shown below.
#
# <img src="./fig1.png" style="background-color:white;width:1300px;" >
#
# ## Data Description:
# We are trying to predict the winner of the 2016 Presidential election (Trump vs. Clinton) in each county in the US. To do this, we will consider several predictors, including `minority`, the percentage of residents that are minorities, and `bachelor`, the percentage of resident adults with a bachelor's degree (or higher).
#
# ## Instructions:
#
# - Read the datafile `county_election_train.csv` into a Pandas data frame.
# - Create the response variable based on the columns `trump` and `clinton`.
# - Initialize a Decision Tree classifier of depth 3 and fit it on the training data.
# - Visualize the Decision Tree.
#
# ## Hints:
#
# <a href="https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html" target="_blank">sklearn.tree.DecisionTreeClassifier()</a> A decision tree classifier.
#
# <a href="https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier.fit" target="_blank">classifier.fit()</a> Build a decision tree classifier from the training set (X, y).
#
# <a href="https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.scatter.html" target="_blank">plt.scatter()</a> A scatter plot of y vs. x with varying marker size and/or color.
#
# <a href="https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.xlabel.html" target="_blank">plt.xlabel()</a> Set the label for the x-axis.
#
# <a href="https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.ylabel.html" target="_blank">plt.ylabel()</a> Set the label for the y-axis.
#
# <a href="https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html" target="_blank">plt.legend()</a> Place a legend on the Axes.
#
# <a href="https://scikit-learn.org/stable/modules/generated/sklearn.tree.plot_tree.html" target="_blank">tree.plot_tree()</a> Plot a decision tree.
#
# **Note: This exercise is auto-graded and you can try multiple attempts.**

# %%
# Import necessary libraries
import numpy as np
import pandas as pd
import sklearn as sk
import seaborn as sns
from sklearn import tree
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import cross_val_score

pd.set_option('display.width', 100)
pd.set_option('display.max_columns', 20)
plt.rcParams["figure.figsize"] = (12, 8)

# %%
# Read the datafile "county_election_train.csv" as a Pandas dataframe
elect_train = pd.read_csv("../DATA/county_election_train.csv")

# Read the datafile "county_election_test.csv" as a Pandas dataframe
elect_test = pd.read_csv("../DATA/county_election_test.csv")

# Take a quick look at the dataframe
elect_train.head()

# %%
### edTest(test_response) ###

# Creating the response variable

# Label each training row '1' where the "trump" value exceeds the "clinton" value, else '0'
# (the labels are strings, not integers)
y_train = np.where(elect_train['trump'] > elect_train['clinton'], '1', '0')

# Label each test row the same way
y_test = np.where(elect_test['trump'] > elect_test['clinton'], '1', '0')

# %%
# Plot "minority" vs "bachelor" as a scatter plot
# Colour counties won by Trump red and counties won by Clinton blue

vote_train = pd.DataFrame(y_train.reshape(-1, 1), columns=['won'])

# One colour per county: 'r' where the label is '1' (Trump), 'b' otherwise (Clinton)
colors = ['r' if vote_result == '1' else 'b' for vote_result in vote_train['won']]
plt.scatter(elect_train['minority'], elect_train['bachelor'], c=colors)
plt.xlabel('minority')
plt.ylabel('bachelor')

vote_train.head()
vote_train.shape

# %%
# Initialize a Decision Tree classifier of depth 3 and use
# Gini impurity as the splitting criterion
dtree = DecisionTreeClassifier(max_depth=3, criterion='gini')

# Fit the classifier on the train data,
# using only the "minority" column as the predictor variable
x = elect_train[['minority']]
y = vote_train['won']
dtree.fit(x, y)

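# %% [markdown]
# The test labels `y_test` are created above but never used. As a minimal sketch of
# how a depth-3 tree could be scored, the cell below fits on synthetic data. This is
# an assumption made so the cell runs on its own: the county CSV files are not read
# here, and the names `x_demo`, `y_demo` and `demo_tree` are hypothetical. With the
# real data, the same calls would presumably take `elect_test[['minority']]` and `y_test`.

# %%
import numpy as np
import pandas as pd
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the county data: one predictor and string labels '0'/'1'
rng = np.random.default_rng(0)
x_demo = pd.DataFrame({'minority': rng.uniform(0, 100, size=400)})

# Label '1' mostly where "minority" is low, with Gaussian noise blurring the boundary
noise = rng.normal(0, 10, size=400)
y_demo = np.where(x_demo['minority'] + noise < 40, '1', '0')

# Hold out a quarter of the rows to play the role of the test file
x_tr, x_te, y_tr, y_te = train_test_split(
    x_demo, y_demo, test_size=0.25, random_state=0
)

demo_tree = DecisionTreeClassifier(max_depth=3, criterion='gini', random_state=0)
demo_tree.fit(x_tr, y_tr)

# Accuracy on the held-out rows, plus 5-fold cross-validation on the training rows
test_accuracy = demo_tree.score(x_te, y_te)
cv_scores = cross_val_score(demo_tree, x_tr, y_tr, cv=5)
print(f"Held-out accuracy: {test_accuracy:.3f}")
print(f"5-fold CV accuracy: {cv_scores.mean():.3f} +/- {cv_scores.std():.3f}")
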
# %%
# Code to set the size of the plot
plt.figure(figsize=(30, 20))

# Plot the Decision Tree trained above with the display options set to True
tree.plot_tree(
    decision_tree=dtree,
    filled=True,
    impurity=True,
    node_ids=True,
    proportion=True,
    rounded=True,
)

plt.show()

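# %% [markdown]
# When a figure is inconvenient (for example in a terminal or a log file), the same
# splits can be printed as text with `sklearn.tree.export_text`. The cell below is a
# hedged sketch on synthetic data, not the exercise's dataset; the names `x_text`,
# `y_text` and `text_tree` are hypothetical. For the tree fitted above, the call
# would presumably be `export_text(dtree, feature_names=['minority'])`.

# %%
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text

# Synthetic stand-in: label '1' where "minority" is below 35, else '0'
rng = np.random.default_rng(1)
x_text = pd.DataFrame({'minority': rng.uniform(0, 100, size=300)})
y_text = np.where(x_text['minority'] < 35, '1', '0')

text_tree = DecisionTreeClassifier(max_depth=3, criterion='gini', random_state=0)
text_tree.fit(x_text, y_text)

# One line per node: split thresholds on "minority" and the predicted class at each leaf
rules = export_text(text_tree, feature_names=['minority'])
print(rules)
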