---
title: "Using Caret to solve classification problems"
output:
  html_document:
    highlight: tango
    theme: united
    toc: yes
    toc_depth: 3
    toc_float:
      collapsed: no
      smooth_scroll: yes
---
#Introduction
Classification is a supervised learning problem: we have observations with several variables, including a class label, and wish to build a model that accurately predicts the class of new observations. (Unsupervised clustering, by contrast, groups observations without using labels.)
There are many applications and many different methods, and so many different packages in R.
The caret package wraps many implementations of classification methods in a common framework that standardises the input and output. This is useful for comparing the performance of different methods on the same data.
See https://topepo.github.io/caret/index.html for very detailed documentation.
#Load the libraries
```{r message=F}
library(DT)
library(ggplot2)
library(cowplot)
library(caret)
library(gbm)
library(randomForest)
library(nnet)
```
If you are not on Windows, you can use the doMC package to run the code in parallel.
```{r,message=F}
library(doMC)
registerDoMC(cores = 5)
```
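doMC relies on forking, which is not available on Windows. A minimal sketch of an alternative, assuming the doParallel package is installed (it is not loaded elsewhere in this tutorial), registers a socket cluster instead. The worker count of 2 is an arbitrary choice for the example.

```r
library(doParallel)

# Start a socket-based cluster; this also works on Windows
cl <- makePSOCKcluster(2)
registerDoParallel(cl)

# Number of workers that foreach (and therefore caret) will use
nWorkers <- getDoParWorkers()
nWorkers

# ... calls to train() would go here ...

# Shut the workers down and return to sequential execution
stopCluster(cl)
registerDoSEQ()
```

Stopping the cluster at the end avoids leaving idle R processes running in the background.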
#Exploratory analysis
###How many variables and observations?
```{r}
dim(iris)
```
###Get a feel for the data
```{r}
datatable(iris,rownames = F)
summary(iris)
```
###PCA
```{r warning=FALSE}
# perform a PCA on the four numeric measurements (the Species column is dropped)
pca <- prcomp(iris[,-5])
# the contribution to the total variance for each component
percentVar <- pca$sdev^2 / sum( pca$sdev^2 )
# assemble the data for the plot
d <- data.frame(PC1=pca$x[,1], PC2=pca$x[,2],species=as.factor(iris$Species))
ggplot(data=d, aes(x=PC1, y=PC2, colour=species)) +
  geom_point(size=3) +
  xlab(paste0("PC1: ",round(percentVar[1] * 100),"% variance")) +
  ylab(paste0("PC2: ",round(percentVar[2] * 100),"% variance")) +
  coord_fixed() +
  scale_colour_discrete(name="Species") +
  theme_cowplot()
```
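As a quick check on the variance calculation above, the same proportions can be read from `summary()` of the PCA object. The sketch below uses base R only; the names `pcaCheck`, `propVar` and `reported` are introduced here for illustration.

```r
# PCA on the four numeric iris measurements, as in the chunk above
pcaCheck <- prcomp(iris[, -5])

# Proportion of the total variance explained by each component
propVar <- pcaCheck$sdev^2 / sum(pcaCheck$sdev^2)

# summary() reports the same quantities (rounded)
reported <- summary(pcaCheck)$importance["Proportion of Variance", ]

round(propVar, 5)
reported

# Cumulative proportion: how much variance the first two components retain
cumsum(propVar)[2]
```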
###What features are likely to be discriminative?
```{r}
featurePlot(iris[,1:4],as.factor(iris[,5]), plot="boxplot")
```
#Create training and test datasets
The createDataPartition function splits the data into training and test datasets, sampling within each class so that the class proportions are preserved. We will use the training set to optimise the predictive model, while the test set is held back to simulate new data.
```{r}
set.seed(42)
trainIndex <- createDataPartition(iris$Species, p = .8,
list = FALSE,
times = 1)
irisTrain <- iris[ trainIndex,]
irisTest <- iris[-trainIndex,]
table(irisTest$Species)
```
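createDataPartition samples within each level of the outcome, so both sets keep the class balance. The idea can be sketched in base R as follows. This is an illustration of stratified sampling, not caret's actual implementation, and the names `idxByClass`, `sketchTrain` and `sketchTest` are made up for the example.

```r
set.seed(42)

# Row indices grouped by species
idxByClass <- split(seq_len(nrow(iris)), iris$Species)

# Sample 80% of the rows within each species
sketchIdx <- unlist(lapply(idxByClass, function(idx) {
  sample(idx, size = round(0.8 * length(idx)))
}))

sketchTrain <- iris[sketchIdx, ]
sketchTest  <- iris[-sketchIdx, ]

# Each species contributes 40 rows to training and 10 to testing
table(sketchTrain$Species)
table(sketchTest$Species)
```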
#Train a model
The trainControl function sets up the cross-validation method. We optimise the model by choosing the best parameters within the training dataset. In 10-fold cross-validation the training dataset is split into 10 parts: 9 parts are used to train the model with a given set of parameters and the remaining part is used to test its performance. This is repeated so that each part is held out once.
```{r}
# 10-fold CV repeated 50 times
fitControl <- trainControl(method = "repeatedcv",
                           number = 10,
                           repeats = 50)
```
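To make the resampling scheme concrete, here is a minimal base R sketch of one round of 10-fold splitting, plus the arithmetic for how many models are fitted during resampling. It assumes 120 training rows (80% of iris) and the 10 values of k tried in the kNN grid in this tutorial. caret additionally tries to balance the folds by class, which this simple version does not.

```r
set.seed(42)
nTrain   <- 120  # assumed size of the training set
nFolds   <- 10
nRepeats <- 50
nK       <- 10   # candidate values of k in the tuning grid

# Randomly assign every row to one of the folds
foldId <- sample(rep(seq_len(nFolds), length.out = nTrain))
table(foldId)

# For fold 1: hold these rows out and train on the rest
heldOut   <- which(foldId == 1)
trainRows <- which(foldId != 1)
length(heldOut)
length(trainRows)

# Number of model fits during resampling: folds x repeats x grid size
nFits <- nFolds * nRepeats * nK
nFits
```

The fit count grows quickly with the number of repeats and the size of the grid, which is why running in parallel can be worthwhile.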
Get the parameters for the knn model
```{r}
getModelInfo("knn",regex = F)
```
Set up the search grid for the model
```{r}
knnGrid <- expand.grid(k=1:10)
```
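For models with more than one tuning parameter, expand.grid builds every combination of the supplied values. As a hedged illustration, the sketch below sets up a grid for the gbm method, whose tuning parameters in caret are n.trees, interaction.depth, shrinkage and n.minobsinnode. The specific values are arbitrary choices for the example, not recommendations.

```r
# Every combination of the candidate values: 3 x 3 x 1 x 1 = 9 rows
gbmGridSketch <- expand.grid(n.trees = c(50, 100, 150),
                             interaction.depth = 1:3,
                             shrinkage = 0.1,
                             n.minobsinnode = 10)
nrow(gbmGridSketch)
head(gbmGridSketch)
```

Such a grid could be passed to train through the tuneGrid argument, in the same way as knnGrid.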
Train the model
```{r}
set.seed(42)
knnFit <- train(Species~ ., data = irisTrain,
method = "knn",
trControl = fitControl,
tuneGrid=knnGrid
)
knnFit
```
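To see what the tuning parameter k controls, here is a minimal base R sketch of a k-nearest-neighbour prediction for a single observation. It is an illustration of the technique only, not the implementation caret uses, and `knnPredictOne` is a hypothetical helper written for this example.

```r
# Predict the class of one new observation by majority vote
# among its k nearest neighbours
knnPredictOne <- function(trainX, trainY, newX, k = 5) {
  # Euclidean distance from the new observation to every training row
  diffs <- sweep(as.matrix(trainX), 2, unlist(newX))
  dists <- sqrt(rowSums(diffs^2))
  # Labels of the k closest training rows
  nearest <- order(dists)[seq_len(k)]
  votes <- table(trainY[nearest])
  # Most common label (ties go to the first factor level)
  names(votes)[which.max(votes)]
}

# Classify the first iris flower using all the other flowers as training data
firstPred <- knnPredictOne(iris[-1, 1:4], iris$Species[-1], iris[1, 1:4], k = 5)
firstPred
```

A small k follows the training data closely, while a larger k smooths the decision boundary; cross-validation is what picks between them.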
Plot the resampling profiles
```{r}
ggplot(knnFit)
ggplot(knnFit,metric = "Kappa")
```
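The second plot uses Kappa, which compares the observed accuracy with the accuracy expected by chance given the class frequencies. A worked example in base R follows, using a made-up confusion matrix (the counts are hypothetical, not results from this tutorial).

```r
# Hypothetical confusion matrix: rows are predictions, columns are true classes
classes <- c("setosa", "versicolor", "virginica")
cm <- matrix(c(10, 0, 0,
                0, 9, 1,
                0, 1, 9),
             nrow = 3, byrow = TRUE,
             dimnames = list(Prediction = classes, Reference = classes))

n <- sum(cm)

# Observed accuracy: proportion of cases on the diagonal
observedAcc <- sum(diag(cm)) / n

# Accuracy expected by chance from the row and column totals
expectedAcc <- sum(rowSums(cm) * colSums(cm)) / n^2

# Cohen's kappa: agreement beyond chance, scaled to a maximum of 1
kappaStat <- (observedAcc - expectedAcc) / (1 - expectedAcc)

observedAcc
expectedAcc
kappaStat
```

With three equally sized classes the chance accuracy is one third, so Kappa is somewhat lower than the raw accuracy.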
#Make predictions using the best model
```{r}
# Cross-validated confusion matrix for the best model: cell percentages averaged over the resamples of the training set
confusionMatrix(knnFit)
```
#Train other models for comparison
```{r, message=F}
#RandomForest
set.seed(42)
rfFit <- train(Species~ ., data = irisTrain,
method = "rf",
trControl = fitControl
)
rfFit
#Support Vector Machine
set.seed(42)
svmGrid <- expand.grid(C= 2^c(0:5))
svmFit <- train(Species~ ., data = irisTrain,
method = "svmLinear",
trControl = fitControl,
tuneGrid=svmGrid
)
svmFit
#Gradient Boosted Model
set.seed(42)
gbmFit <- train(Species~ ., data = irisTrain,
method = "gbm",
trControl = fitControl,
verbose=F
)
gbmFit
#Neural Network
set.seed(42)
nnFit <- train(Species~ ., data = irisTrain,
method = "nnet",
trControl = fitControl,
maxit = 1000,
trace=FALSE
)
nnFit
```
#Compare the model performance
```{r}
models <- resamples(list(knn=knnFit,rf=rfFit,svm=svmFit,gbm=gbmFit,nn=nnFit))
summary(models)
```
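Because every model is evaluated on the same resamples, the differences between models can also be examined pairwise. The following self-contained sketch shows the pattern with two quick stand-in models (kNN with a fixed k of 5, and linear discriminant analysis, which is not otherwise used in this tutorial) and plain 5-fold cross-validation on the full iris data. The settings are chosen only to keep the example fast.

```r
library(caret)

ctrlSketch <- trainControl(method = "cv", number = 5)

# Resetting the seed before each call gives both models the same folds
set.seed(42)
knnSketch <- train(Species ~ ., data = iris, method = "knn",
                   trControl = ctrlSketch,
                   tuneGrid = expand.grid(k = 5))
set.seed(42)
ldaSketch <- train(Species ~ ., data = iris, method = "lda",
                   trControl = ctrlSketch)

# Collect the resampling results and compare the models pairwise
sketchModels <- resamples(list(knn = knnSketch, lda = ldaSketch))
modelDiffs <- diff(sketchModels)
summary(modelDiffs)

# Box-and-whisker plot of the resampled Accuracy and Kappa values
bwplot(sketchModels)
```

The same calls, `diff()` and `bwplot()`, could be applied to a resamples object built from the models trained in this tutorial.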
#Predict the test dataset samples
```{r}
predictions<-predict(nnFit,irisTest)
confusionMatrix(predictions,irisTest$Species )
```