Code
library("ggplot2")
library("reshape2")
library("SISIR")
The aim of this practical is to analyze the data described in Baragatti et al. (2019) using random forests with interval selection. It reproduces part of the results described in Servien and Vialaneix (2023), which are available in this public repository.
Before you start, make sure that you have downloaded the data as described in the file `../data/README.md`!
Also make sure that the required R packages are installed (using `renv` is strongly recommended):
library("ggplot2")
library("reshape2")
library("SISIR")
# Weather variable to analyze: "P" (here, rainfall). The data file name is
# derived from it.
pvar <- "P"
input_file <- sprintf("../data/data_truffles_%s.rda", pvar)
# Fail early with an actionable message when the data were not downloaded.
if (!file.exists(input_file)) {
  stop("Data file '", input_file, "' not found: download the data as ",
       "described in ../data/README.md first.", call. = FALSE)
}
# load() restores the stored objects into the current environment and returns
# their names; x, Y and beta are all used below.
loaded <- load(input_file)
stopifnot(all(c("x", "Y", "beta") %in% loaded))
`x` contains the weather data (here, rainfall):
# Dimensions of the weather data: one row per year, one column per month
dim(x)
[1] 25 15
# First rows of the weather data
head(x)
`Y` contains the corresponding truffle yield:
# Truffle yield, one value per row (year) of x
Y
[1] 1.000 57.100 0.500 22.500 0.450 5.000 31.290 24.800 54.550 0.000
[11] 33.400 5.000 0.000 0.000 0.000 62.600 23.050 33.560 0.000 27.125
[21] 0.000 0.000 0.000 0.000 19.950
and `beta` contains the ground truth, i.e., the months whose rainfall impacts the truffle yield:
# Ground truth: 1 flags a month whose rainfall impacts the yield, 0 otherwise
beta
[1] 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0
The evolution of the yield over the years is obtained with:
# Yield per year; the year labels are the row names of x.
df <- data.frame(yield = Y, year = rownames(x))
# Map y to the `yield` column of df (not the global Y) so that the plot only
# depends on its data argument.
p <- ggplot(df, aes(x = year, y = yield)) +
  geom_point() +
  theme_bw() +
  ylab("yield") +
  xlab("year") +
  theme(axis.text.x = element_text(angle = 90))
p
The weather data can be visualized as follows (the color scale corresponds to the yield, and the dark red segment spans the months flagged by the ground truth):
# Long format: one row per (year, month) pair, with the rainfall in `value`
# and the month name in `variable`.
df <- x
df$year <- rownames(df)
df$yield <- Y
df <- melt(df, id.vars = c("yield", "year"))
# First and last months flagged by the ground truth.
important_months <- names(x)[as.logical(beta)]
start_s <- important_months[1]
end_s <- important_months[length(important_months)]
p <- ggplot(df, aes(x = variable, y = value, colour = yield, group = year)) +
  geom_line() +
  theme_bw() +
  ylab("rainfall") +
  xlab("month") +
  # annotate() draws a single segment (and does not inherit the global
  # aesthetics); geom_segment() with constant positions would draw one
  # identical segment per row of df. y = 300 is a fixed display height.
  annotate("segment", x = start_s, xend = end_s, y = 300, yend = 300,
           colour = "darkred") +
  theme(axis.text.x = element_blank())
p
We first create a dataset containing all the combinations of the different versions of the three steps of the method:
# Candidate methods for each of the three steps of SFCB.
group_methods <- c("adjclust", "cclustofvar")
summary_methods <- c("pls", "basics", "cclustofvar")
selection_methods <- c("none", "boruta", "relief")
# One row per combination of methods (the selection method varies fastest).
sfcb_variants <- expand.grid(selection = selection_methods,
                             summary = summary_methods,
                             group = group_methods,
                             stringsAsFactors = FALSE)
# removing options that are not really compatible (adjclust grouping with
# cclustofvar summaries); they are kept aside in `removed`
to_remove <- sfcb_variants$group == "adjclust" &
  sfcb_variants$summary == "cclustofvar"
removed <- sfcb_variants[to_remove, ]
sfcb_variants <- sfcb_variants[!to_remove, ]
sfcb_variants
Then, for all these variants, we run the SFCB method:
# Run SFCB once per variant (row of sfcb_variants). The ground truth beta is
# attached to each result so that its quality can later be computed from the
# result object alone. seed = 123 is fixed for every run (presumably for
# reproducibility; see ?sfcb).
res_SFCB <- lapply(seq_len(nrow(sfcb_variants)), function(rnum) {
  out <- sfcb(x, Y,
              group.method = sfcb_variants$group[rnum],
              summary.method = sfcb_variants$summary[rnum],
              selection.method = sfcb_variants$selection[rnum],
              seed = 123, at = 5)
  out$beta <- beta
  out
})
We select one of the results (the one using adjclust for the grouping, PLS to compute summaries and Boruta to select variables):
# Result of the variant using adjclust grouping, PLS summaries and Boruta
# selection; each combination appears once in sfcb_variants, so `which()`
# returns a single index as required by `[[`.
selected <- sfcb_variants$group == "adjclust" &
  sfcb_variants$summary == "pls" &
  sfcb_variants$selection == "boruta"
cur_res <- res_SFCB[[which(selected)]]
cur_res
Call:
sfcb(X = x, Y = Y, group.method = sfcb_variants$group[rnum],
summary.method = sfcb_variants$summary[rnum], selection.method = sfcb_variants$selection[rnum],
at = 5, seed = 123)
SFCB object with:
- 5 interval(s)
- 2 selected interval(s)
- 5 repeats
- MSE ranging in [272.8621, 285.8808]
- computational time (total): 0.94 (seconds)
The selected intervals can be displayed along with the clustering of time points with:
# Plot the selected intervals together with the clustering of time points
plot(cur_res)
Reversals detected in the dendrogram. Rectangles are not relevant and thus, they are not displayed.
As the ground truth is available, quality criteria (precision and recall with respect to the ground truth) can be computed:
# Compare the selected intervals with the ground truth (adds precision/recall)
quality(cur_res, beta)
Call:
sfcb(X = x, Y = Y, group.method = sfcb_variants$group[rnum],
summary.method = sfcb_variants$summary[rnum], selection.method = sfcb_variants$selection[rnum],
at = 5, seed = 123)
SFCB object with:
- 5 interval(s)
- 2 selected interval(s)
- 5 repeats
- MSE ranging in [272.8621, 285.8808]
- computational time (total): 0.94 (seconds)
- precision wrt ground truth in [0.4285714, 0.4285714]
- recall wrt ground truth in [0.6, 0.6]
Now, let’s create a function that outputs quality criteria for a given SFCB
object:
#' Add summary quality criteria to an SFCB result
#'
#' @param res_SFCB An object returned by `sfcb()`, with the ground truth
#'   stored in its `beta` element.
#' @return `res_SFCB` with extra elements: `ARI`, `Precision` and `Recall`
#'   (all `NA` when the object has no `selected` element, presumably when no
#'   selection method was used), `bmse` (smallest MSE in `res_SFCB$mse$mse`)
#'   and `time` (total of `res_SFCB$computational.times`).
compute_SFCB_qualities <- function(res_SFCB) {
  if ("selected" %in% names(res_SFCB)) {
    # computing Rand index, Precision, Recall against the ground truth
    res_SFCB <- quality(res_SFCB, res_SFCB$beta)
    res_SFCB$ARI <- res_SFCB$quality$ARI
    res_SFCB$Precision <- res_SFCB$quality$Precision
    res_SFCB$Recall <- res_SFCB$quality$Recall
  } else {
    # otherwise there is no selection to evaluate: criteria are NA
    res_SFCB$ARI <- NA
    res_SFCB$Precision <- NA
    res_SFCB$Recall <- NA
  }
  # computing MSE (best over the stored values)
  res_SFCB$bmse <- min(res_SFCB$mse$mse)
  # computing computational time (total)
  res_SFCB$time <- sum(res_SFCB$computational.times)
  res_SFCB
}
We apply it to all results and gather the criteria in a single table:
# Add the quality criteria to every SFCB result.
res_SFCB <- lapply(res_SFCB, compute_SFCB_qualities)
# One value per variant for each criterion. vapply() enforces a single numeric
# value per result (a logical NA is promoted to a numeric NA), unlike sapply()
# whose output type depends on its input.
sfcb_variants$mse <- vapply(res_SFCB, "[[", numeric(1), "bmse")
sfcb_variants$precision <- vapply(res_SFCB, "[[", numeric(1), "Precision")
sfcb_variants$recall <- vapply(res_SFCB, "[[", numeric(1), "Recall")
sfcb_variants[["adjusted Rand"]] <- vapply(res_SFCB, "[[", numeric(1), "ARI")
sfcb_variants[["computational time"]] <- vapply(res_SFCB, "[[", numeric(1),
                                                "time")
# Display with the method columns in group / summary / selection order.
sfcb_variants[c(3:1, 4:8)]
Some additional clean-up (mostly adding `NA` values for the variants that were removed as incompatible):
# Short column names that are convenient in ggplot2 aesthetics; renaming by
# name (not position) fails loudly if the expected columns are missing.
renamed <- match(c("adjusted Rand", "computational time"),
                 names(sfcb_variants))
stopifnot(!anyNA(renamed))
names(sfcb_variants)[renamed] <- c("arand", "time")
# Ordered factor so that facets and legends follow none < boruta < relief.
sfcb_variants$selection <- factor(sfcb_variants$selection,
                                  levels = c("none", "boruta", "relief"),
                                  ordered = TRUE)
# Add back the incompatible variants with NA criteria so that every
# combination of methods is present in sfcb_variants.
criteria <- c("mse", "precision", "recall", "arand", "time")
fill_removed <- data.frame(matrix(NA, nrow = nrow(removed),
                                  ncol = length(criteria),
                                  dimnames = list(NULL, criteria)))
removed <- data.frame(removed, fill_removed)
sfcb_variants <- rbind(sfcb_variants, removed)
Comparison of precision and recall:
# Keep only the variants that use a selection method.
cur_res <- sfcb_variants[sfcb_variants$selection != "none", ]
# Drop the now-unused "none" level while keeping the order of the others.
cur_res$selection <- factor(cur_res$selection, levels = c("boruta", "relief"),
                            ordered = TRUE)
p <- ggplot(cur_res, aes(x = precision, y = recall, colour = group,
                         shape = summary)) +
  geom_point(size = 2) +
  theme_bw() +
  ylim(0, 1) +
  scale_x_continuous(breaks = c(0, 0.5, 1), labels = c("0", "0.5", "1"),
                     limits = c(0, 1))
p
Comparison of F\(_1\) scores:
# F1 score: harmonic mean of precision and recall. It is NaN (0 / 0) when
# both are zero, and NA for the variants without results.
cur_res$f1 <- 2 * cur_res$precision * cur_res$recall /
  (cur_res$precision + cur_res$recall)
# geom_col() plots the values as they are (same as geom_bar(stat = "identity")).
p <- ggplot(cur_res, aes(x = summary, y = f1, fill = group)) +
  geom_col(position = "dodge") +
  facet_grid(~ selection) +
  theme_bw() +
  ylab(expression(F[1] ~ " score")) +
  theme(axis.title.x = element_blank())
p
Note: the empty bar for P prediction with cclustofvar/PLS/relief corresponds to the case where precision and recall are both zero (hence, the F\(_1\) score is undefined and computed as `NaN`).
Comparison of mean square errors:
# Best (smallest) MSE of each variant, one panel per selection method.
# geom_col() plots the values as they are (same as geom_bar(stat = "identity")).
p <- ggplot(sfcb_variants, aes(x = summary, y = mse, fill = group)) +
  geom_col(position = "dodge") +
  facet_grid(~ selection) +
  theme_bw() +
  ylab("Mean square error") +
  theme(axis.title.x = element_blank())
p
Comparison of computation times:
# Total computational time of each variant, one panel per selection method.
# geom_col() plots the values as they are (same as geom_bar(stat = "identity")).
p <- ggplot(sfcb_variants, aes(x = summary, y = time, fill = group)) +
  geom_col(position = "dodge") +
  facet_grid(~ selection) +
  theme_bw() +
  theme(axis.title.x = element_blank()) +
  ylab("Computation time")
p
This work is licensed under a Creative Commons Attribution 4.0 International License.
The code is distributed under the GPL-3 licence.
# Print the R version, platform and package versions used for this document
sessionInfo()
R version 4.3.1 (2023-06-16)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 22.04.3 LTS
Matrix products: default
BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3
LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.20.so; LAPACK version 3.10.0
locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
[3] LC_TIME=fr_FR.UTF-8 LC_COLLATE=en_US.UTF-8
[5] LC_MONETARY=fr_FR.UTF-8 LC_MESSAGES=en_US.UTF-8
[7] LC_PAPER=fr_FR.UTF-8 LC_NAME=C
[9] LC_ADDRESS=C LC_TELEPHONE=C
[11] LC_MEASUREMENT=fr_FR.UTF-8 LC_IDENTIFICATION=C
time zone: Europe/Paris
tzcode source: system (glibc)
attached base packages:
[1] parallel stats graphics grDevices utils datasets methods
[8] base
other attached packages:
[1] SISIR_0.2.2 doParallel_1.0.17 iterators_1.0.14 foreach_1.5.2
[5] reshape2_1.4.4 ggplot2_3.4.3
loaded via a namespace (and not attached):
[1] ellipse_0.5.0 gtable_0.3.4 capushe_1.1.1
[4] shape_1.4.6 xfun_0.40 ggrepel_0.9.3
[7] lattice_0.21-8 vctrs_0.6.3 tools_4.3.1
[10] generics_0.1.3 tibble_3.2.1 fansi_1.0.4
[13] cluster_2.1.4 rARPACK_0.11-0 pkgconfig_2.0.3
[16] Matrix_1.6-0 RColorBrewer_1.1-3 mixOmics_6.24.0
[19] sparseMatrixStats_1.12.2 lifecycle_1.0.3 farver_2.1.1
[22] compiler_4.3.1 stringr_1.5.0 munsell_0.5.0
[25] aricode_1.0.2 codetools_0.2-19 Boruta_8.0.0
[28] htmltools_0.5.6 yaml_2.3.7 glmnet_4.1-8
[31] tidyr_1.3.0 pillar_1.9.0 MASS_7.3-60
[34] BiocParallel_1.34.2 CORElearn_1.57.3 viridis_0.6.4
[37] rpart_4.1.19 RSpectra_0.16-1 tidyselect_1.2.0
[40] digest_0.6.33 stringi_1.7.12 purrr_1.0.2
[43] dplyr_1.1.3 labeling_0.4.3 splines_4.3.1
[46] adjclust_0.6.7 fastmap_1.1.1 grid_4.3.1
[49] colorspace_2.1-0 expm_0.999-7 cli_3.6.1
[52] magrittr_2.0.3 survival_3.5-5 utf8_1.2.3
[55] corpcor_1.6.10 withr_2.5.0 scales_1.2.1
[58] plotrix_3.8-2 rmarkdown_2.24 matrixStats_1.0.0
[61] igraph_1.5.1 rpart.plot_3.1.1 nnet_7.3-19
[64] gridExtra_2.3 ranger_0.15.1 evaluate_0.21
[67] knitr_1.43 viridisLite_0.4.2 rlang_1.1.1
[70] Rcpp_1.0.11 dendextend_1.17.1 glue_1.6.2
[73] jsonlite_1.8.7 R6_2.5.1 plyr_1.8.8
[76] MatrixGenerics_1.12.3