Advertisement
YaBoiSwayZ

PlotPulse v2 - R edition (Regression Diagnostics Visualiser)

Jul 2nd, 2024
474
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
R 8.55 KB | Source Code | 0 0
  1. library(ggplot2)
  2. library(plotly)
  3. library(randomForest)
  4. library(e1071)
  5. library(rpart)
  6. library(pdp)
  7.  
  #' Print a timestamped status line to the console.
  #'
  #' @param message Text to print after the timestamp.
  #' @param verbose When `FALSE` nothing is printed.
  #' @return `NULL`, invisibly; called for its console side effect.
  log_message <- function(message, verbose = TRUE) {
    if (verbose) {
      timestamp <- Sys.time()
      stamped_line <- sprintf("[%s] %s\n", timestamp, message)
      cat(stamped_line)
    }
  }
  14.  
  #' Copy the plot on the active graphics device to a PDF file.
  #'
  #' @param save_path Path of the PDF file to write.
  #' @param create_dir If `TRUE`, create the parent directory of `save_path`
  #'   (recursively) when it does not exist; if `FALSE`, warn and give up.
  #' @param figsize Numeric vector; elements 1 and 2 are passed as the PDF
  #'   width and height.
  #' @return `TRUE` when the file was written, `FALSE` when the target
  #'   directory is missing and was not (or could not be) created.
  save_plot <- function(save_path, create_dir, figsize) {
    dir_name <- dirname(save_path)

    if (!dir.exists(dir_name)) {
      if (!create_dir) {
        warning("The directory specified in save_path does not exist. The plot will not be saved.")
        log_message("Failed to save plot - directory does not exist")
        return(FALSE)
      }
      # dir.create() reports failure through its return value (plus a
      # warning), so check it rather than failing later inside the PDF copy.
      if (!dir.create(dir_name, recursive = TRUE)) {
        log_message(sprintf("Failed to create directory: %s", dir_name))
        return(FALSE)
      }
      log_message(sprintf("Directory created: %s", dir_name))
    }

    # dev.copy2pdf() opens the PDF device, copies the plot, closes the PDF
    # device and switches back to the source device on its own. An extra
    # dev.off() here would close the caller's device and discard the plot.
    dev.copy2pdf(file = save_path, width = figsize[1], height = figsize[2])
    log_message(sprintf("Plot saved to %s", save_path))
    TRUE
  }
  33.  
  #' Draw a classifier's decision regions over the first two columns of `X`.
  #'
  #' A 200 x 200 grid spanning the range of the first two columns of `X` is
  #' passed to `predict()`. The grid keeps the column names of `X` so that
  #' models fitted through a formula can find their predictors. Any further
  #' columns of `X` are held constant: numeric columns at their mean, other
  #' columns at their first observed value.
  #'
  #' @param model Fitted model with a `predict()` method that accepts a data
  #'   frame of new observations.
  #' @param X Data frame or matrix of predictors with at least two columns;
  #'   the first two must be numeric and define the plot axes.
  #' @param y Observed response, one value per row of `X`; colours the points.
  #' @param color Accepted for signature compatibility; not used, because
  #'   point colours come from `y`.
  #' @param plot_params List overriding any of `xlab`, `ylab`, `theme`,
  #'   `boundary_colors` and `contour`; missing entries fall back to defaults.
  #' @param ... Further arguments forwarded to `geom_point()`.
  #' @return A ggplot object.
  plot_decision_boundary <- function(model, X, y, color = "blue",
                                     plot_params = list(), ...) {
    # Fill in defaults so that a direct call with plot_params = list() works.
    defaults <- list(
      xlab = "X-axis",
      ylab = "Y-axis",
      theme = theme_minimal(),
      boundary_colors = c("red", "blue"),
      contour = TRUE
    )
    plot_params <- modifyList(defaults, plot_params)

    features <- as.data.frame(X)
    if (ncol(features) < 2) {
      stop("X must have at least two columns to draw a decision boundary.")
    }
    if (is.null(y) || length(y) != nrow(features)) {
      stop("y must supply one value for every row of X.")
    }
    axis_names <- names(features)[1:2]
    if (!all(vapply(features[axis_names], is.numeric, logical(1)))) {
      stop("The first two columns of X must be numeric.")
    }

    # Grid over the two plotted predictors, named as in X.
    grid <- expand.grid(lapply(
      features[axis_names],
      function(values) {
        seq(min(values, na.rm = TRUE), max(values, na.rm = TRUE),
            length.out = 200)
      }
    ))
    # Remaining predictors are required by predict() but are not plotted.
    for (extra in setdiff(names(features), axis_names)) {
      values <- features[[extra]]
      grid[[extra]] <- if (is.numeric(values)) {
        mean(values, na.rm = TRUE)
      } else {
        values[1]
      }
    }

    surface <- data.frame(
      X1 = grid[[axis_names[1]]],
      X2 = grid[[axis_names[2]]],
      Prediction = predict(model, grid)
    )
    observed <- data.frame(
      X1 = features[[axis_names[1]]],
      X2 = features[[axis_names[2]]],
      Class = y
    )

    p <- ggplot(surface, aes(x = X1, y = X2)) +
      geom_tile(aes(fill = Prediction), alpha = 0.3) +
      geom_point(data = observed, aes(color = Class), ...) +
      ggtitle("Decision Boundary") +
      labs(x = plot_params$xlab, y = plot_params$ylab) +
      plot_params$theme

    # A manual scale is only valid for a discrete response, and only when it
    # has a colour for every class; otherwise keep ggplot2's default scale.
    is_discrete <- is.factor(y) || is.character(y) || is.logical(y)
    n_classes <- if (is.factor(y)) nlevels(y) else length(unique(y))
    if (is_discrete && length(plot_params$boundary_colors) >= n_classes) {
      p <- p + scale_color_manual(values = plot_params$boundary_colors)
    }

    if (isTRUE(plot_params$contour)) {
      p <- p + geom_contour(aes(z = as.numeric(Prediction)), color = "black")
    }

    p
  }
  57.  
  # Collect the per-observation quantities used by the interactive
  # diagnostic plots. Columns the model class cannot supply are left out.
  .plot_model_diagnostics <- function(model, X = NULL, y = NULL) {
    if (inherits(model, "randomForest")) {
      if (is.null(X) || is.null(y)) {
        stop("X and y are required to compute residuals for 'randomForest' models.")
      }
      fitted_values <- predict(model, X)
      columns <- list(Fitted = fitted_values, Residuals = y - fitted_values)
    } else {
      columns <- list(Fitted = fitted(model), Residuals = residuals(model))
    }
    if (inherits(model, c("lm", "glm", "lme4"))) {
      columns$StdResiduals <- rstandard(model)
      columns$Leverage <- hatvalues(model)
      columns$CookD <- cooks.distance(model)
    }
    columns <- Filter(function(values) length(values) > 0, columns)
    do.call(data.frame, columns)
  }

  # Draw the requested plot on the active graphics device (base or lattice).
  .plot_model_static <- function(model, X, y, plot_type, plot_params, ...) {
    if (plot_type == "partial_dependence") {
      if (!inherits(model, "randomForest")) {
        stop("Partial dependence plots are only supported for 'randomForest' models.")
      }
      if (is.null(X)) {
        stop("X is required for partial dependence plots.")
      }
      pd_plot <- partial(model, pred.var = colnames(X), grid.resolution = 50,
                         plot = TRUE)
      # The returned plot object is not drawn automatically inside a function.
      print(pd_plot)
      return(invisible(NULL))
    }

    if (plot_type == "decision_boundary") {
      if (!inherits(model, "svm")) {
        stop("Decision boundary plots are only supported for 'svm' models.")
      }
      if (is.null(X) || is.null(y)) {
        stop("X and y are required for decision boundary plots.")
      }
      if (is.null(model$terms)) {
        stop("Static decision boundary plots need an 'svm' fitted through the formula interface.")
      }
      predictors <- as.data.frame(X)
      if (ncol(predictors) < 2) {
        stop("X must have at least two columns to draw a decision boundary.")
      }
      # plot.svm() expects the training data including the response column,
      # plus a formula naming the two plotted dimensions (y-axis ~ x-axis).
      response_name <- all.vars(formula(terms(model))[[2]])[1]
      svm_data <- predictors
      svm_data[[response_name]] <- y
      axes <- names(predictors)[1:2]
      boundary_formula <- reformulate(axes[1], response = axes[2])
      # Hold the remaining numeric predictors at their mean instead of
      # plot.svm()'s default of 0.
      held <- setdiff(names(predictors), axes)
      held <- held[vapply(predictors[held], is.numeric, logical(1))]
      held_constant <- lapply(predictors[held], mean, na.rm = TRUE)
      plot(model, data = svm_data, formula = boundary_formula,
           slice = held_constant)
      return(invisible(NULL))
    }

    if (!inherits(model, "lm")) {
      stop("Static diagnostic plots rely on plot.lm() and are only supported for 'lm' and 'glm' models.")
    }
    plot_number <- switch(
      plot_type,
      "residual" = 1,
      "qq" = 2,
      "scale_location" = 3,
      "cooks" = 4,
      "residual_leverage" = 5,
      "cooks_leverage" = 6
    )
    # plot.lm() sets xlab/ylab itself; passing them again fails with
    # "formal argument matched by multiple actual arguments".
    plot(model, which = plot_number, main = plot_params$main, ...)
    invisible(NULL)
  }

  # Build the ggplot object that plot_model() converts with ggplotly().
  .plot_model_interactive <- function(model, X, y, plot_type, color,
                                      plot_params, ...) {
    if (plot_type == "partial_dependence") {
      if (!inherits(model, "randomForest")) {
        stop("Partial dependence plots are only supported for 'randomForest' models.")
      }
      pred_vars <- colnames(X)
      if (length(pred_vars) < 1L || length(pred_vars) > 2L) {
        stop("Interactive partial dependence plots need X to have one or two named columns.")
      }
      pd <- as.data.frame(
        partial(model, pred.var = pred_vars, grid.resolution = 50)
      )
      if (length(pred_vars) == 1L) {
        pd_frame <- data.frame(X = pd[[pred_vars]], yhat = pd[["yhat"]])
        return(
          ggplot(pd_frame, aes(x = X, y = yhat)) +
            geom_line(color = color, ...) +
            ggtitle("Partial Dependence Plot") +
            labs(x = plot_params$xlab, y = plot_params$ylab) +
            plot_params$theme
        )
      }
      # Two predictors: show the partial dependence surface as a heat map.
      pd_frame <- data.frame(
        X1 = pd[[pred_vars[1]]],
        X2 = pd[[pred_vars[2]]],
        yhat = pd[["yhat"]]
      )
      return(
        ggplot(pd_frame, aes(x = X1, y = X2, fill = yhat)) +
          geom_tile() +
          ggtitle("Partial Dependence Plot") +
          labs(x = plot_params$xlab, y = plot_params$ylab) +
          plot_params$theme
      )
    }

    if (plot_type == "decision_boundary") {
      if (!inherits(model, "svm")) {
        stop("Decision boundary plots are only supported for 'svm' models.")
      }
      return(
        plot_decision_boundary(model, X, y, color = color,
                               plot_params = plot_params)
      )
    }

    plot_data <- .plot_model_diagnostics(model, X, y)
    required <- switch(
      plot_type,
      "residual" = c("Fitted", "Residuals"),
      "qq" = "StdResiduals",
      "scale_location" = c("Fitted", "StdResiduals"),
      "cooks" = "CookD",
      "residual_leverage" = c("Leverage", "StdResiduals"),
      "cooks_leverage" = c("Leverage", "CookD")
    )
    unavailable <- setdiff(required, names(plot_data))
    if (length(unavailable) > 0) {
      stop(sprintf(
        "Plot type '%s' needs %s, which is not available for this model.",
        plot_type, paste(unavailable, collapse = ", ")
      ))
    }

    switch(
      plot_type,
      "residual" = ggplot(plot_data, aes(Fitted, Residuals)) +
        geom_point(color = color, ...) +
        ggtitle("Residual Plot") +
        labs(x = plot_params$xlab, y = plot_params$ylab) +
        plot_params$theme,
      "qq" = ggplot(plot_data, aes(sample = StdResiduals)) +
        geom_qq(color = color, ...) +
        geom_qq_line(color = "red") +
        ggtitle("QQ Plot") +
        labs(x = plot_params$xlab, y = plot_params$ylab) +
        plot_params$theme,
      "scale_location" = ggplot(plot_data,
                                aes(Fitted, sqrt(abs(StdResiduals)))) +
        geom_point(color = color, ...) +
        ggtitle("Scale-Location Plot") +
        labs(x = plot_params$xlab, y = plot_params$ylab) +
        plot_params$theme,
      "cooks" = ggplot(plot_data, aes(seq_along(CookD), CookD)) +
        geom_bar(stat = "identity", color = color, ...) +
        ggtitle("Cook's Distance Plot") +
        labs(x = plot_params$xlab, y = plot_params$ylab) +
        plot_params$theme,
      "residual_leverage" = ggplot(plot_data, aes(Leverage, StdResiduals)) +
        geom_point(color = color, ...) +
        ggtitle("Residuals vs Leverage Plot") +
        labs(x = plot_params$xlab, y = plot_params$ylab) +
        plot_params$theme,
      "cooks_leverage" = ggplot(plot_data, aes(Leverage, CookD)) +
        geom_point(color = color, ...) +
        ggtitle("Cook's Distance vs Leverage Plot") +
        labs(x = plot_params$xlab, y = plot_params$ylab) +
        plot_params$theme
    )
  }

  #' Draw a diagnostic, partial dependence or decision boundary plot.
  #'
  #' With `interactive = FALSE` the plot is drawn on the active graphics
  #' device and optionally copied to a PDF; with `interactive = TRUE` a
  #' ggplot is built and converted with `ggplotly()`.
  #'
  #' @param model Fitted model inheriting from one of the classes listed in
  #'   the error message below.
  #' @param y Observed response. Needed for decision boundary plots and for
  #'   interactive diagnostics of 'randomForest' models.
  #' @param X Predictors (data frame). Needed for partial dependence and
  #'   decision boundary plots and for 'randomForest' diagnostics.
  #' @param plot_type One of "residual", "qq", "scale_location", "cooks",
  #'   "residual_leverage", "cooks_leverage", "partial_dependence",
  #'   "decision_boundary".
  #' @param figsize Numeric vector of length 2: PDF width and height used
  #'   when `save_path` is given.
  #' @param color Single colour name used by the interactive geoms.
  #' @param save_path Optional PDF path for the static plot.
  #' @param plot_params List overriding `main`, `xlab`, `ylab`, `theme`,
  #'   `boundary_colors` or `contour`. Static diagnostic plots use `main`
  #'   only, because `plot.lm()` fixes its own axis labels.
  #' @param interactive If `TRUE`, return a plotly object instead of drawing.
  #' @param verbose If `TRUE`, log progress messages.
  #' @param create_dir Passed to `save_plot()`: create the directory of
  #'   `save_path` when it is missing.
  #' @param ... Forwarded to `plot()` for static diagnostics and to the
  #'   ggplot2 geom for interactive diagnostics.
  #' @return The plotly object when `interactive = TRUE`. Otherwise,
  #'   invisibly, the recorded plot, or `NULL` when the active device keeps
  #'   no display list.
  plot_model <- function(
    model,
    y = NULL,
    X = NULL,
    plot_type = "residual",
    figsize = c(10, 6),
    color = "blue",
    save_path = NULL,
    plot_params = list(),
    interactive = FALSE,
    verbose = TRUE,
    create_dir = FALSE,
    ...
  ) {
    log_message("Starting plot_model function", verbose)

    if (!inherits(model, c("lm", "glm", "lme4", "randomForest", "svm", "rpart"))) {
      stop("This function supports objects of class 'lm', 'glm', 'lme4', 'randomForest', 'svm', and 'rpart'. Please provide a valid model object.")
    }

    valid_plot_types <- c("residual", "qq", "scale_location", "cooks", "residual_leverage", "cooks_leverage", "partial_dependence", "decision_boundary")
    if (!is.character(plot_type) || length(plot_type) != 1 ||
          !plot_type %in% valid_plot_types) {
      stop(paste("Unsupported plot_type. Choose from:", paste(valid_plot_types, collapse = ", ")))
    }

    if (!is.numeric(figsize) || length(figsize) != 2) {
      stop("figsize must be a numeric vector of length 2.")
    }

    if (!is.character(color) || length(color) != 1) {
      stop("color must be a single character string.")
    }

    default_params <- list(
      main = "",
      xlab = "X-axis",
      ylab = "Y-axis",
      theme = theme_minimal(),
      boundary_colors = c("red", "blue"),
      contour = TRUE
    )
    plot_params <- modifyList(default_params, plot_params)

    result <- tryCatch({
      if (interactive) {
        log_message("Creating interactive plot", verbose)
        gg <- .plot_model_interactive(model, X, y, plot_type, color,
                                      plot_params, ...)
        widget <- ggplotly(gg, tooltip = c("x", "y", "text"))
        log_message("Returning interactive plot object", verbose)
        widget
      } else {
        # No par(mfrow = ...) here: one plot is drawn per call, and a 2 x 2
        # layout would squeeze it into a quarter of the device.
        .plot_model_static(model, X, y, plot_type, plot_params, ...)

        # Record before saving so the result does not depend on what
        # save_plot() does to the graphics devices.
        recorded <- if (dev.cur() > 1L && dev.displaylist()) {
          recordPlot()
        } else {
          NULL
        }
        if (!is.null(save_path)) {
          save_plot(save_path, create_dir, figsize)
        }
        log_message("Returning plot object", verbose)
        recorded
      }
    }, error = function(e) {
      log_message(sprintf("An error occurred: %s", e$message), verbose)
      stop(e)
    })

    log_message("plot_model function completed", verbose)
    if (interactive) result else invisible(result)
  }
  222.  
  223. # Usage
  224. # model <- lm(mpg ~ wt, data = mtcars)
  225. # plot_model(model, plot_type = "residual")
  226. # plot_model(model, plot_type = "qq")
  227. # plot_model(model, plot_type = "scale_location")
  228. # plot_model(model, plot_type = "cooks")
  229. # plot_model(model, plot_type = "residual_leverage")
  230. # plot_model(model, plot_type = "cooks_leverage")
  231. # model_rf <- randomForest(mpg ~ wt + hp, data = mtcars)
  232. # plot_model(model_rf, X = mtcars[, c("wt", "hp")], plot_type = "partial_dependence")
  233. # model_svm <- svm(Species ~ ., data = iris)
  234. # plot_model(model_svm, X = iris[, -5], y = iris$Species, plot_type = "decision_boundary")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement