diff --git a/NEWS.md b/NEWS.md index 9cc36989..fc867b8c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,7 @@ ingredients 0.5.0 --------------------------------------------------------------- * `feature_importance` now does `15` permutations on each variable by default. Use the `B` argument to change this number * added boxplots to `plot.feature_importance` and `plotD3.feature_importance` that showcase the permutation data +* in `aggregate_profiles`: preserve `_x_` column factor order and sort its values [#82](https://github.com/ModelOriented/ingredients/issues/82) ingredients 0.4.2 --------------------------------------------------------------- diff --git a/R/aggregate_profiles.R b/R/aggregate_profiles.R index d9222870..443e7358 100644 --- a/R/aggregate_profiles.R +++ b/R/aggregate_profiles.R @@ -318,6 +318,7 @@ aggregated_profiles_partial <- function(all_profiles, groups = NULL) { # as in https://github.com/ModelOriented/ingredients/issues/82 if (!is.numeric(all_profiles$`_x_`)) { aggregated_profiles$`_x_` <- factor(aggregated_profiles$`_x_`, levels = unique(all_profiles$`_x_`)) + aggregated_profiles <- aggregated_profiles[order(aggregated_profiles$`_x_`),] } aggregated_profiles @@ -352,6 +353,10 @@ aggregated_profiles_conditional <- function(all_profiles, groups = NULL, span = # for categorical variables we will calculate weighted average # but weights are 0-1, 1 if it's the same level and 0 otherwise split_profile$`_w_` <- split_profile$`_orginal_` == split_profile$`_x_` + + # for factors, keep proper order + # as in https://github.com/ModelOriented/ingredients/issues/82 + split_profile$`_x_` <- factor(split_profile$`_x_`, levels = unique(split_profile$`_x_`)) } per_points <- split(split_profile, split_profile[, c("_x_", groups)]) @@ -373,12 +378,6 @@ aggregated_profiles_conditional <- function(all_profiles, groups = NULL, span = aggregated_profiles$`_label_` <- paste(aggregated_profiles$`_label_`, aggregated_profiles$`_groups_`, sep = "_") } - - # for factors, keep proper order - # as in https://github.com/ModelOriented/ingredients/issues/82 - if (!is.numeric(all_profiles$`_x_`)) { - aggregated_profiles$`_x_` <- factor(aggregated_profiles$`_x_`, levels = unique(all_profiles$`_x_`)) - } aggregated_profiles$`_ids_` <- 0 aggregated_profiles diff --git a/R/plotD3_aggregated_profiles.R b/R/plotD3_aggregated_profiles.R index 6ee63e9b..c5da8579 100644 --- a/R/plotD3_aggregated_profiles.R +++ b/R/plotD3_aggregated_profiles.R @@ -84,7 +84,7 @@ plotD3.aggregated_profiles_explainer <- function(x, ..., size = 2, alpha = 1, ymax <- max(aggregated_profiles$`_yhat_`) ymin <- min(aggregated_profiles$`_yhat_`) - ymargin <- abs(ymax - ymin)*0.1; + ymargin <- abs(ymax - ymin)*0.1 aggregated_profiles_list <- split(aggregated_profiles, aggregated_profiles$`_vname_`) @@ -120,7 +120,8 @@ plotD3.aggregated_profiles_explainer <- function(x, ..., size = 2, alpha = 1, ret }) - ymean <- round(attr(x, "mean_prediction"),3) + + ymean <- ifelse("partial_dependency_explainer" %in% class(x), round(attr(x, "mean_prediction"), 3), 0) } options <- list(variableNames = as.list(all_variables), diff --git a/docs/404.html b/docs/404.html index fbb15a5c..59be3773 100644 --- a/docs/404.html +++ b/docs/404.html @@ -36,12 +36,12 @@ + - diff --git a/docs/articles/index.html b/docs/articles/index.html index dbe0ac93..ab977a79 100644 --- a/docs/articles/index.html +++ b/docs/articles/index.html @@ -36,12 +36,12 @@ + - diff --git a/docs/articles/vignette_describe.html b/docs/articles/vignette_describe.html index 99fd8c5f..1b47f769 100644 --- a/docs/articles/vignette_describe.html +++ b/docs/articles/vignette_describe.html @@ -100,7 +100,7 @@

Explanations in natural language

Adam Izdebski

-

2019-12-12

+

2019-12-17

Source: vignettes/vignette_describe.Rmd @@ -120,42 +120,42 @@

The ingredients package allows for generating prediction validation and predition perturbation explanations. They allow for both global and local model explanation.

Generic function decribe() generates a natural language description for explanations generated with feature_importance(), ceteris_paribus() functions.

To show generating automatic descriptions we first load the data set and build a random forest model classifying, which of the passangers survived sinking of the titanic. Then, using DALEX package, we generate an explainer of the model. Lastly we select a random passanger, which prediction’s should be explained.

-
library("DALEX")
-library("ingredients")
-library("randomForest")
-titanic <- na.omit(titanic)
-
-model_titanic_rf <- randomForest(survived == "yes" ~ .,
-                                 data = titanic)
-
-explain_titanic_rf <- explain(model_titanic_rf,
-                            data = titanic[,-9],
-                            y = titanic$survived == "yes",
-                            label = "Random Forest")
+
library("DALEX")
+library("ingredients")
+library("randomForest")
+titanic <- na.omit(titanic)
+
+model_titanic_rf <- randomForest(survived == "yes" ~ .,
+                                 data = titanic)
+
+explain_titanic_rf <- explain(model_titanic_rf,
+                            data = titanic[,-9],
+                            y = titanic$survived == "yes",
+                            label = "Random Forest")
#> Preparation of a new explainer is initiated
 #>   -> model label       :  Random Forest 
 #>   -> data              :  2099  rows  8  cols 
 #>   -> target variable   :  2099  values 
 #>   -> predict function  :  yhat.randomForest  will be used ( [33m default [39m )
-#>   -> predicted values  :  numerical, min =  0.004235992 , mean =  0.3234211 , max =  0.992679  
+#>   -> predicted values  :  numerical, min =  0.005387424 , mean =  0.3239327 , max =  0.9908313  
 #>   -> residual function :  difference between y and yhat ( [33m default [39m )
-#>   -> residuals         :  numerical, min =  -0.8101414 , mean =  0.001019088 , max =  0.8896769  
+#>   -> residuals         :  numerical, min =  -0.8092138 , mean =  0.0005074922 , max =  0.8898482  
 #>   -> model_info        :  package randomForest , ver. 4.6.14 , task regression ( [33m default [39m ) 
 #>  [32m A new explainer has been created! [39m
-
passanger <- titanic[sample(nrow(titanic), 1) ,-9]
-passanger
-
#>     gender age class    embarked country   fare sibsp parch
-#> 499   male  26   3rd Southampton Denmark 7.1701     1     0
+
passanger <- titanic[sample(nrow(titanic), 1) ,-9]
+passanger
+
#>      gender age            class embarked country fare sibsp parch
+#> 2003   male  26 engineering crew  Belfast England    0     0     0

Now we are ready for generating various explantions and then describing it with describe() function.

Feature Importance

Feature importance explanation shows the importance of all the model’s variables. As it is a global explanation technique, no passanger need to be specified.

-
importance_rf <- feature_importance(explain_titanic_rf)
-plot(importance_rf)
+
importance_rf <- feature_importance(explain_titanic_rf)
+plot(importance_rf)

Function describe() easily describes which variables are the most important. Argument nonsignificance_treshold as always sets the level above which variables become significant. For higher treshold, less variables will be described as significant.

-
describe(importance_rf)
+
describe(importance_rf)
#> The number of important variables for Random Forest's prediction is 65 out of 108. 
 #>  Variables _baseline_, _baseline_, _baseline_ have the highest importantance.
@@ -163,87 +163,85 @@

Ceteris Paribus Profiles

Ceteris Paribus profiles shows how the model’s input changes with the change of a specified variable.

-
perturbed_variable <- "class"
-cp_rf <- ceteris_paribus(explain_titanic_rf,
-                         passanger,
-                         variables = perturbed_variable)
-plot(cp_rf, variable_type = "categorical")
+
perturbed_variable <- "class"
+cp_rf <- ceteris_paribus(explain_titanic_rf,
+                         passanger,
+                         variables = perturbed_variable)
+plot(cp_rf, variable_type = "categorical")

For a user with no experience, interpreting the above plot may be not straightforward. Thus we generate a natural language description in order to make it easier.

-
describe(cp_rf)
-
#> For the selected instance, prediction estimated by Random Forest is equal to 0.054.
+
describe(cp_rf)
+
#> For the selected instance, prediction estimated by Random Forest is equal to 0.148.
 #> 
-#> Model's prediction would increase substantially if the value of class variable would change to "deck crew", "1st", "engineering crew", "victualling crew".
+#> Model's prediction would increase substantially if the value of class variable would change to "deck crew", "1st", "2nd".
 #> The largest change would be marked if class variable would change to "deck crew".
 #> 
-#> Other variables are with less importance and they do not change prediction by more than 0.04%.
+#> Other variables are with less importance and they do not change prediction by more than 0.07%.

Natural lannguage descriptions should be flexible in order to provide the desired level of complexity and specificity. Thus various parameters can modify the description being generated.

-
describe(cp_rf,
-         display_numbers = TRUE,
-         label = "the probability that the passanger will survive")
-
#> Random Forest predicts that for the selected instance, the probability that the passanger will survive is equal to 0.054
+
+
#> Random Forest predicts that for the selected instance, the probability that the passanger will survive is equal to 0.148
 #> 
-#> The most important change in Random Forest's prediction would occur for class = "deck crew". It increases the prediction by 0.262. 
-#> The second most important change in the prediction would occur for class = "1st". It increases the prediction by 0.167. 
-#> The third most important change in the prediction would occur for class = "engineering crew". It increases the prediction by 0.071.
+#> The most important change in Random Forest's prediction would occur for class = "deck crew". It increases the prediction by 0.451. 
+#> The second most important change in the prediction would occur for class = "1st". It increases the prediction by 0.205. 
+#> The third most important change in the prediction would occur for class = "2nd". It increases the prediction by 0.106.
 #> 
-#> Other variable values are with less importance. They do not change the the probability that the passanger will survive by more than 0.066.
+#> Other variable values are with less importance. They do not change the the probability that the passanger will survive by more than 0.048.

Please note, that describe() can handle only one variable at a time, so it is recommended to specify, which variables should be described.

-
describe(cp_rf,
-         display_numbers = TRUE,
-         label = "the probability that the passanger will survive",
-         variables = perturbed_variable)
-
#> Random Forest predicts that for the selected instance, the probability that the passanger will survive is equal to 0.054
+
+
#> Random Forest predicts that for the selected instance, the probability that the passanger will survive is equal to 0.148
 #> 
-#> The most important change in Random Forest's prediction would occur for class = "deck crew". It increases the prediction by 0.262. 
-#> The second most important change in the prediction would occur for class = "1st". It increases the prediction by 0.167. 
-#> The third most important change in the prediction would occur for class = "engineering crew". It increases the prediction by 0.071.
+#> The most important change in Random Forest's prediction would occur for class = "deck crew". It increases the prediction by 0.451. 
+#> The second most important change in the prediction would occur for class = "1st". It increases the prediction by 0.205. 
+#> The third most important change in the prediction would occur for class = "2nd". It increases the prediction by 0.106.
 #> 
-#> Other variable values are with less importance. They do not change the the probability that the passanger will survive by more than 0.066.
+#> Other variable values are with less importance. They do not change the the probability that the passanger will survive by more than 0.048.

Continuous variables are described as well.

-
perturbed_variable_continuous <- "age"
-cp_rf <- ceteris_paribus(explain_titanic_rf,
-                         passanger)
-plot(cp_rf, variables = perturbed_variable_continuous)
+
perturbed_variable_continuous <- "age"
+cp_rf <- ceteris_paribus(explain_titanic_rf,
+                         passanger)
+plot(cp_rf, variables = perturbed_variable_continuous)

-
describe(cp_rf, variables = perturbed_variable_continuous)
-
#> Random Forest predicts that for the selected instance, prediction is equal to 0.054
+
describe(cp_rf, variables = perturbed_variable_continuous)
+
#> Random Forest predicts that for the selected instance, prediction is equal to 0.148
 #> 
-#> The highest prediction occurs for (age = 4), while the lowest for (age = 27).
-#> Breakpoint is identified at (age = 13).
+#> The highest prediction occurs for (age = 2), while the lowest for (age = 45).
+#> Breakpoint is identified at (age = 9).
 #> 
-#> Average model responses are *higher* for variable values *lower* than breakpoint (= 13).
+#> Average model responses are *higher* for variable values *lower* than breakpoint (= 9).

Ceteris Paribus profiles are described only for a single observation. If we want to access the influence of more than one observation, we need to describe dependency profiles.

Partial Dependency Profiles

-
pdp <- aggregate_profiles(cp_rf, type = "partial")
-plot(pdp, variables = "fare")
+
pdp <- aggregate_profiles(cp_rf, type = "partial")
+plot(pdp, variables = "fare")

-
describe(pdp, variables = "fare")
-
#> Random Forest's mean prediction is equal to 0.054.
+
describe(pdp, variables = "fare")
+
#> Random Forest's mean prediction is equal to 0.148.
 #> 
-#> The highest prediction occurs for (fare = 110.343588), while the lowest for (fare = 14.1).
-#> Breakpoint is identified at (fare = 35.1).
+#> The highest prediction occurs for (fare = 30.1), while the lowest for (fare = 14.1).
+#> Breakpoint is identified at (fare = 27.144).
 #> 
-#> Average model responses are *higher* for variable values *higher* than breakpoint (= 35.1).
-
pdp <- aggregate_profiles(cp_rf, type = "partial", variable_type = "categorical")
-plot(pdp, variables = perturbed_variable)
+#> Average model responses are *higher* for variable values *higher* than breakpoint (= 27.144).
+
pdp <- aggregate_profiles(cp_rf, type = "partial", variable_type = "categorical")
+plot(pdp, variables = perturbed_variable)

-
describe(pdp, variables = perturbed_variable)
-
#> Random Forest's mean prediction is equal to 0.054.
+
describe(pdp, variables = perturbed_variable)
+
#> Random Forest's mean prediction is equal to 0.148.
 #> 
-#> Model's prediction would increase substantially if the value of class variable would change to "deck crew", "1st", "engineering crew".
-#> The largest change would be marked if class variable would change to "restaurant staff".
+#> Model's prediction would increase substantially if the value of class variable would change to "deck crew", "2nd", "victualling crew". On the other hand, Random Forest's prediction would decrease substantially if the value of class variable would change to "1st". The largest change would be marked if class variable would change to "victualling crew".
 #> 
-#> Other variables are with less importance and they do not change prediction by more than 0.04%.
+#> Other variables are with less importance and they do not change prediction by more than 0.07%.

Profile clustering

-
passangers <- select_sample(titanic, n = 100)
-
-sp_rf <- ceteris_paribus(explain_titanic_rf, passangers)
-clust_rf <- cluster_profiles(sp_rf, k = 3)
-head(clust_rf)
+
passangers <- select_sample(titanic, n = 100)
+
+sp_rf <- ceteris_paribus(explain_titanic_rf, passangers)
+clust_rf <- cluster_profiles(sp_rf, k = 3)
+head(clust_rf)
#> Top profiles    : 
 #>   _vname_            _label_       _x_ _cluster_    _yhat_ _ids_
-#> 1    fare Random Forest v7_1 0.0000000         1 0.1974579     0
-#> 2   sibsp Random Forest v7_1 0.0000000         1 0.1694334     0
-#> 3   parch Random Forest v7_1 0.0000000         1 0.1733480     0
-#> 4     age Random Forest v7_1 0.1666667         1 0.4792587     0
-#> 5   parch Random Forest v7_1 0.2800000         1 0.1733451     0
-#> 6   sibsp Random Forest v7_1 1.0000000         1 0.1635364     0
-
plot(sp_rf, alpha = 0.1) +
-  show_aggregated_profiles(clust_rf, color = "_label_", size = 2)
+#> 1 fare Random Forest v7_1 0.0000000 1 0.2041661 0 +#> 2 sibsp Random Forest v7_1 0.0000000 1 0.1779100 0 +#> 3 parch Random Forest v7_1 0.0000000 1 0.1819234 0 +#> 4 age Random Forest v7_1 0.1666667 1 0.4506699 0 +#> 5 parch Random Forest v7_1 0.2800000 1 0.1819217 0 +#> 6 sibsp Random Forest v7_1 1.0000000 1 0.1737733 0 +
plot(sp_rf, alpha = 0.1) +
+  show_aggregated_profiles(clust_rf, color = "_label_", size = 2)

Session info

-
sessionInfo()
-
#> R version 3.6.1 (2019-07-05)
-#> Platform: x86_64-apple-darwin15.6.0 (64-bit)
-#> Running under: macOS Mojave 10.14.4
+
+
#> R version 3.6.0 (2019-04-26)
+#> Platform: x86_64-w64-mingw32/x64 (64-bit)
+#> Running under: Windows 10 x64 (build 17763)
 #> 
 #> Matrix products: default
-#> BLAS:   /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRblas.0.dylib
-#> LAPACK: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRlapack.dylib
 #> 
 #> locale:
-#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
+#> [1] LC_COLLATE=Polish_Poland.1250  LC_CTYPE=Polish_Poland.1250   
+#> [3] LC_MONETARY=Polish_Poland.1250 LC_NUMERIC=C                  
+#> [5] LC_TIME=Polish_Poland.1250    
 #> 
 #> attached base packages:
 #> [1] stats     graphics  grDevices utils     datasets  methods   base     
@@ -290,12 +290,12 @@ 

#> [1] ingredients_0.5.0 randomForest_4.6-14 DALEX_0.4.9 #> #> loaded via a namespace (and not attached): -#> [1] Rcpp_1.0.3 compiler_3.6.1 pillar_1.4.2 tools_3.6.1 +#> [1] Rcpp_1.0.3 compiler_3.6.0 pillar_1.4.2 tools_3.6.0 #> [5] digest_0.6.23 evaluate_0.14 memoise_1.1.0 lifecycle_0.1.0 #> [9] tibble_2.1.3 gtable_0.3.0 pkgconfig_2.0.3 rlang_0.4.2 -#> [13] rstudioapi_0.10 yaml_2.2.0 pkgdown_1.4.1 xfun_0.11 -#> [17] stringr_1.4.0 dplyr_0.8.3 knitr_1.26 desc_1.2.0 -#> [21] fs_1.3.1 rprojroot_1.3-2 grid_3.6.1 tidyselect_0.2.5 +#> [13] rstudioapi_0.10 yaml_2.2.0 pkgdown_1.4.1 xfun_0.6 +#> [17] stringr_1.4.0 dplyr_0.8.3 knitr_1.22 desc_1.2.0 +#> [21] fs_1.3.1 rprojroot_1.3-2 grid_3.6.0 tidyselect_0.2.5 #> [25] glue_1.3.1 R6_2.4.1 rmarkdown_1.16 farver_2.0.1 #> [29] ggplot2_3.2.1 purrr_0.3.3 magrittr_1.5 backports_1.1.5 #> [33] scales_1.1.0 htmltools_0.4.0 MASS_7.3-51.4 assertthat_0.2.1 @@ -305,7 +305,6 @@