Advertisement
YaBoiSwayZ

PlotPulse v1 - Python edition (Regression Diagnostics Visualiser)

Jun 17th, 2024 (edited)
112
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.08 KB | Source Code | 0 0
import os
from typing import Any, Dict, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import statsmodels.api as sm
from linearmodels.iv.results import IVResults
from sklearn.linear_model import LinearRegression
from statsmodels.genmod.generalized_linear_model import GLMResults
from statsmodels.regression.linear_model import RegressionResults
from statsmodels.robust.robust_linear_model import RLMResults


  13. def plot(model: Union[RegressionResults, IVResults, GLMResults, RLMResults, LinearRegression],
  14.          y: np.ndarray = None,
  15.          X: np.ndarray = None,
  16.          plot_type: str = 'residual',
  17.          figsize: Tuple[int, int] = (10, 6),
  18.          color: str = 'blue',
  19.          marker: str = 'o',
  20.          save_path: str = None,
  21.          plot_params: Dict[str, Any] = None,
  22.          **kwargs):
  23.    
  24.     if plot_params is None:
  25.         plot_params = {}
  26.     elif not isinstance(plot_params, dict):
  27.         raise ValueError("plot_params must be a dictionary.")
  28.    
  29.     # Extract matplotlib and seaborn specific parameters
  30.     fig_params = plot_params.get('figure', {})
  31.     ax_params = plot_params.get('axes', {})
  32.     sns_params = plot_params.get('seaborn', {})
  33.    
  34.     def plot_residuals(fitted_values, residuals):
  35.         fig, ax = plt.subplots(figsize=figsize, **fig_params)
  36.         ax.scatter(fitted_values, residuals, color=color, marker=marker, **kwargs)
  37.         ax.axhline(0, color='red', linestyle='--')
  38.         ax.set_xlabel('Fitted values')
  39.         ax.set_ylabel('Residuals')
  40.         ax.set_title('Residuals vs Fitted values')
  41.         return fig
  42.  
  43.     def plot_qq(residuals):
  44.         fig = sm.qqplot(residuals, line='45')
  45.         plt.title('Q-Q Plot')
  46.         return fig
  47.  
  48.     def plot_leverage_resid2(model):
  49.         fig = sm.graphics.plot_leverage_resid2(model)
  50.         plt.title('Leverage vs. Residuals')
  51.         return fig
  52.  
  53.     def plot_influence(model):
  54.         fig = sm.graphics.plot_influence(model, criterion="cooks")
  55.         plt.title('Influence Plot')
  56.         return fig
  57.  
  58.     def plot_cooks(model):
  59.         fig = sm.graphics.influence_plot(model)
  60.         plt.title("Cook's Distance Plot")
  61.         return fig
  62.  
  63.     def plot_residual_density(residuals):
  64.         fig, ax = plt.subplots(figsize=figsize, **fig_params)
  65.         sns.kdeplot(residuals, shade=True, color=color, ax=ax, **sns_params, **kwargs)
  66.         ax.set_title('Residual Density Plot')
  67.         ax.set_xlabel('Residuals')
  68.         return fig
  69.    
  70.     # Map plot types to their respective functions for statsmodels models
  71.     statsmodels_plot_funcs = {
  72.         'residual': plot_residuals,
  73.         'qq': plot_qq,
  74.         'leverage': plot_leverage_resid2,
  75.         'cooks': plot_cooks,
  76.         'influence': plot_influence,
  77.         'residual_density': plot_residual_density
  78.     }
  79.  
  80.     try:
  81.         if isinstance(model, (RegressionResults, IVResults, GLMResults, RLMResults)):
  82.             if not hasattr(model, 'fittedvalues') or not hasattr(model, 'resid'):
  83.                 raise AttributeError("Model object does not have necessary attributes 'fittedvalues' or 'resid'. Ensure you are passing a valid statsmodels model.")
  84.            
  85.             fitted_values = model.fittedvalues
  86.             residuals = model.resid
  87.            
  88.             plot_func = statsmodels_plot_funcs.get(plot_type)
  89.             if plot_func:
  90.                 fig = plot_func(fitted_values, residuals)
  91.             elif plot_type == 'partial_regression':
  92.                 fig = plt.figure(figsize=figsize, **fig_params)
  93.                 sm.graphics.plot_partregress_grid(model, fig=fig, **plot_params)
  94.                 plt.title('Partial Regression Plots')
  95.             else:
  96.                 raise ValueError("Unsupported plot_type for statsmodels models. Choose from: 'residual', 'qq', 'leverage', 'cooks', 'influence', 'partial_regression', 'residual_density'.")
  97.        
  98.         elif isinstance(model, LinearRegression):
  99.             if not isinstance(y, np.ndarray) or not isinstance(X, np.ndarray):
  100.                 raise ValueError("y and X must be NumPy arrays for LinearRegression models.")
  101.            
  102.             y_pred = model.predict(X)
  103.             residuals = y - y_pred
  104.  
  105.             plot_func = statsmodels_plot_funcs.get(plot_type)
  106.             if plot_func:
  107.                 fig = plot_func(y_pred, residuals)
  108.             elif plot_type == 'leverage':
  109.                 leverage = (X * np.linalg.pinv(X.T @ X) @ X.T).sum(axis=1)
  110.                 fig, ax = plt.subplots(figsize=figsize, **fig_params)
  111.                 ax.scatter(leverage, residuals, color=color, marker=marker, **kwargs)
  112.                 ax.axhline(0, color='red', linestyle='--')
  113.                 ax.set_xlabel('Leverage')
  114.                 ax.set_ylabel('Residuals')
  115.                 ax.set_title('Leverage vs Residuals')
  116.             elif plot_type == 'partial_regression':
  117.                 raise NotImplementedError("Partial regression plots are not implemented for sklearn models.")
  118.             else:
  119.                 raise ValueError("Unsupported plot_type for LinearRegression models. Choose from: 'residual', 'qq', 'leverage', 'cooks', 'residual_density'.")
  120.        
  121.         else:
  122.             raise TypeError("Unsupported model type. Supported types are statsmodels RegressionResults, IVResults, GLMResults, RLMResults, and sklearn LinearRegression.")
  123.  
  124.         if save_path:
  125.             directory = os.path.dirname(save_path)
  126.             if not os.path.exists(directory):
  127.                 os.makedirs(directory)
  128.             fig.savefig(save_path)
  129.         else:
  130.             plt.show()
  131.  
  132.         return fig
  133.  
  134.     except ValueError as e:
  135.         print(f"ValueError: {e}. Ensure that the provided 'plot_type' is correct and that 'y' and 'X' are NumPy arrays for LinearRegression models.")
  136.     except TypeError as e:
  137.         print(f"TypeError: {e}. Supported model types are statsmodels RegressionResults, IVResults, GLMResults, RLMResults, and sklearn LinearRegression.")
  138.     except AttributeError as e:
  139.         print(f"AttributeError: {e}. Ensure the model object has the necessary attributes and is a valid statsmodels model.")
  140.     except NotImplementedError as e:
  141.         print(f"NotImplementedError: {e}. This plot type is not implemented for the provided model type.")
  142.     except Exception as e:
  143.         print(f"An unexpected error occurred: {e}")

# Usage for statsmodels or linearmodels
# plot(model, plot_type='residual')  # Residual plot
# plot(model, plot_type='qq')  # Q-Q plot
# plot(model, plot_type='leverage')  # Leverage plot
# plot(model, plot_type='cooks')  # Cook's distance plot
# plot(model, plot_type='influence')  # Influence plot
# plot(model, plot_type='partial_regression')  # Partial regression plot
# plot(model, plot_type='residual_density')  # Residual density plot

# Usage for scikit-learn
# plot(model, y, X, plot_type='residual')  # Residual plot
# plot(model, y, X, plot_type='qq')  # Q-Q plot
# plot(model, y, X, plot_type='leverage')  # Leverage plot
# plot(model, y, X, plot_type='cooks')  # Cook's distance plot
# plot(model, y, X, plot_type='residual_density')  # Residual density plot
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement