Statystyczna Eksploracja Danych
LABORATORIUM 5
DRZEWA: PODZIAŁ NA ARGUMENTACH
Dzisiejsze zajęcia dotyczą drzew klasyfikujących, omówionych podczas Wykładu 4
Najistotniejszym elemntem konstrukcji drzew jest reguła podziału, która decyduje o tym, w jaki sposób poszczególne elementy skupione w danym węźle zostaną przesunięte do swoich węzłów-dzieci. Ze względu na swoja prostotę obliczeniową będziemy korzystać z miary różnorodności danej wskaźnikiem Gini'ego. Przedstawia się ona w następujący sposób, przy czym \(p\) ułamkiem obserwacji należących do klasy 1
\(Q_{G} = 2 p (1 - p) \)
Rozpoczynamy od wylosowania danych z rozkładów Gaussa
library(MASS)
# Generowanie
draw.data.gauss <- function(S1, S2, m1, m2, n1, n2) {
X1 <- mvrnorm(n1, m1, S1)
X2 <- mvrnorm(n2, m2, S2)
X1 <- data.frame(X1); colnames(X1) <- c("x", "y")
X2 <- data.frame(X2); colnames(X2) <- c("x", "y")
X1$class <- 1; X2$class <- 2
data <- rbind(X1, X2); data$class <- factor(data$class)
return(data)
}
# Rysowanie punktów
plot.data <- function(data) {
cols <- c("blue", "orange")
plot(data[,1:2], col = cols[data$class], cex = 2)
text(data[,1:2], labels = 1:nrow(data), cex = 0.6)
}
# Parametry danych z rozkładu Gaussa
S1 <- matrix(c(4, 2, 2, 4), 2, 2)
S2 <- matrix(c(4, 2, 2, 2), 2, 2)
m1 <- c(-1, -1)
m2 <- c(2, 2)
n1 <- 30
n2 <- 20
# Ustawienie ziarna dla losowania danych
set.seed(128)
# Generowanie obserwacji
data <- draw.data.gauss(S1, S2, m1, m2, n1, n2)
# Rysowanie danych
plot.data(data)
przedstawiających się w następujący sposób
Rysunek 5.1
Teraz należy stworzyć cały "silnik" drzewa, czyli zdefiniowac współczynnik Gini'ego oraz funkcję do wyznaczania różnorodności, a dokładniej różnicy różnorodności pomiędzy rodzicem a dziećmi.
# Funkcja do współczynnika Gini'ego
gini <- function(tab) 2 * tab[1] * tab[2]
# Normalizacja prostego histogramu
norm.tab <- function(tab, tab.s) if(tab.s) { tab / tab.s } else { tab }
# Funkcja do wyznaczania różnorodności
get.Q <- function(data, name, threshold) {
# Rozkład klas w oryginalnej probie
tab.all <- table(data$class)
tab.all.s <- sum(tab.all)
# Rozkład klas dla warunku tj. po "lewej" stronie przedzialu
tab.left <- table(data$class[data[,name] <= threshold])
tab.left.s <- sum(tab.left)
# Rozkład klas po "prawej" stronie
tab.right <- tab.all - tab.left
tab.right.s <- tab.all.s - tab.left.s
# Normalizacja rozkładów
tab.all <- norm.tab(tab.all, tab.all.s)
tab.left <- norm.tab(tab.left, tab.left.s)
tab.right <- norm.tab(tab.right, tab.right.s)
# Wyznaczanie współczynników Gini'ego
# czyli różnorodności w węzłach
Q.all <- gini(tab.all)
Q.left <- gini(tab.left)
Q.right <- gini(tab.right)
# Ułamki elementów w węzłach
p.left <- tab.left.s / tab.all.s
p.right <- 1 - p.left
# Całkowita różnorodność w węzłach-dzieciach
Q.children <- p.left * Q.left + p.right * Q.right
# Zwracamy różnicę różnorodności
return(Q.all - Q.children)
}
W tym momencie nie przedstawia już trudności wyznaczenie różnicy różnorodności dla każdej wartości podziału zarówno argumentu \(x\) jak i \(y\) danych.
# Tworzenie wektorów podziałów
threshold.x <- sort(data$x)
threshold.y <- sort(data$y)
# Obliczanie różnic różnorodności
# dla podziałów na argumencie x oraz y
Q.x <- sapply(threshold.x, function(t) get.Q(data, "x", t))
Q.y <- sapply(threshold.y, function(t) get.Q(data, "y", t))
Wartości te można wykreślić i porównać ze sobą.
# Wartość różnorodności w funkcji progu podziału
plot(threshold.x, Q.x, t = "o", pch = 19, xlab = "prog podzialu (wartosc X lub Y)", ylab = "wartosc Q")
points(threshold.y, Q.y, t = "o", pch = 19, col = "blue")
Rysunek 5.2
Jak widać z wykresu, maksymalna wartość dla argumentu \(x\) jest większa niż dla \(y\), tak więc ona zostanie wybrana jako dzieląca zbiór danych, przy czym biorąc pod uwagę iż de facto mówimy o całym przedziale, wygodnie jest przyjąć wartość średnią z dwóch sąsiednich punktów
threshold.x.max <- (threshold.x[which.max(Q.x)] + threshold.x[which.max(Q.x) + 1]) / 2
abline(v = threshold.x.max, lty = 2, col = "red")
Rysunek 5.3
Można wreszcie zaznaczyć ostateczny podział na oryginalnym wykresie danych.
plot.data(data)
rect(-10, -10, threshold.x.max, 10, col = rgb(0, 0, 1, 0.2), border = NA)
Rysunek 5.4
DRZEWA: PODZIAŁ DANYCH
Korzystając z powyższych funkcji gini(), norm.tab() oraz get.Q(), możemy teraz zaproponować bardzo prostą funkcję dostosowaną do przypadku dwóch argumentów, która wyszuka najlepszy podział na argumentach oraz korzystając z tego rozdzieli dane.
split.data <- function(data) {
threshold.x <- sort(data$x)
threshold.y <- sort(data$y)
Q.x <- sapply(threshold.x, function(t) get.Q(data, "x", t))
Q.y <- sapply(threshold.y, function(t) get.Q(data, "y", t))
if(max(Q.x) > max(Q.y)) {
threshold.x.max <- (threshold.x[which.max(Q.x)] + threshold.x[which.max(Q.x) + 1]) / 2
data.left <- data[data$x <= threshold.x.max,]
data.right <- data[data$x > threshold.x.max,]
return(list(var = "x", var.val = threshold.x.max, data.left = data.left, data.right = data.right))
} else {
threshold.y.max <- (threshold.y[which.max(Q.y)] + threshold.y[which.max(Q.y) + 1]) / 2
data.left <- data[data$y <= threshold.y.max,]
data.right <- data[data$y > threshold.y.max,]
return(list(var = "y", var.val = threshold.y.max, data.left = data.left, data.right = data.right))
}
}
Daje to np. możliwość wizualizacji następnego podziału
blue.op <- rgb(0, 0, 1, 0.2)
orange.op <- rgb(1, 0.66, 0, 0.2)
plot.data(data)
s0 <- split.data(data)
rect(-10, -10, s0$var.val, 10, col = blue.op)
s1.r <- split.data(s0$data.right)
rect(s0$var.val, s1.r$var.val, 10, 10, col = orange.op)
Rysunek 5.5
oraz kolejnych
s2.r <- split.data(s1.r$data.right)
rect(s2.r$var.val, s1.r$var.val, 10, 10, col = orange.op)
rect(s0$var.val, s1.r$var.val, s2.r$var.val, 10, col = orange.op)
s2.l <- split.data(s1.r$data.left)
rect(s2.l$var.val, -10, 10, s1.r$var.val, col = orange.op)
rect(s0$var.val, -10, s2.l$var.val, s1.r$var.val, col = blue.op)
Rysunek 5.6
Warto przy tym zauważyć, że podziały w rogach prawym górnym oraz lewym górnym są ostateczne: w prostokątach pozostały jedynie elementy jednej klasy.
DRZEWA: PEŁNA FUNKCJA
Powyższe rozważania dają nam możliwość wykonania prymitywnej funkcji do implementacji drzewa klasyfikującego w przypadku dwóch argumentów. Funkcja działa rekurencyjnie i ma na celu dopowoadzenie do takiej sytuacji, aby w liściach znajdowały się jedynie elementy jednej klasy
# Rekurencyjna prymitywna funkcja implementująca drzewa
split.data <- function(data, sp) {
tab <- table(data$class)
if(gini(tab) < 1e-6) {
cat(sp, "Lisc drzewa, klasa A: ", tab[1], "klasa B:", tab[2], "\n")
} else {
threshold.x <- sort(data$x)
threshold.y <- sort(data$y)
Q.x <- sapply(threshold.x, function(t) get.Q(data, "x", t))
Q.y <- sapply(threshold.y, function(t) get.Q(data, "y", t))
if(max(Q.x) > max(Q.y)) {
threshold.x.max <- (threshold.x[which.max(Q.x)] + threshold.x[which.max(Q.x) + 1]) / 2
data.left <- data[data$x <= threshold.x.max,]
data.right <- data[data$x > threshold.x.max,]
cat(sp, "Podzial na X", threshold.x.max, "\n")
} else {
threshold.y.max <- (threshold.y[which.max(Q.y)] + threshold.y[which.max(Q.y) + 1]) / 2
data.left <- data[data$y <= threshold.y.max,]
data.right <- data[data$y > threshold.y.max,]
cat(sp, "Podzial na Y", threshold.y.max, "\n")
}
sp <- paste(sp," ")
split.data(data.left, sp)
split.data(data.right, sp)
}
}
dającej poniższy efekt
split.data(data, "")
Podzial na X 0.9428313
Podzial na Y 0.6573468
Lisc drzewa, klasa A: 18 klasa B: 0
Podzial na Y 1.535119
Podzial na Y 1.233252
Podzial na Y 0.8255293
Lisc drzewa, klasa A: 0 klasa B: 1
Lisc drzewa, klasa A: 1 klasa B: 0
Lisc drzewa, klasa A: 0 klasa B: 1
Lisc drzewa, klasa A: 6 klasa B: 0
Podzial na Y 0.765695
Podzial na X 3.061114
Podzial na Y -1.348321
Podzial na Y -1.927858
Lisc drzewa, klasa A: 1 klasa B: 0
Lisc drzewa, klasa A: 0 klasa B: 1
Lisc drzewa, klasa A: 3 klasa B: 0
Lisc drzewa, klasa A: 0 klasa B: 1
Podzial na X 2.227567
Podzial na X 1.930539
Lisc drzewa, klasa A: 0 klasa B: 5
Lisc drzewa, klasa A: 1 klasa B: 0
Lisc drzewa, klasa A: 0 klasa B: 11
PAKIET RPART
Oczywiście, celem powyższych rozważań i bardzo prostych funkcji jest jedynie ilustracja problemu. Do wykorzystywania drzew klasyfikujących w pakiecie R jest dedykowana biblioteka rpart, której główną funkcją jest imienniczka pakietu rpart(). Poniżej wywołamy ją z argumentami minsplit=1 oraz minbucket=1. Pierwszy z parametrów określa minimalną liczbę elementów, które muszą się znaleźć w weźle, aby dochodziło do podziału. Drugi - minimalną liczbę obserwacji, która musi znaleźć się w węźle, aby został uznany za liść. Wybrane przez nas parametry są bardzo "agresywne", czyli de facto będą raczej prowadzić do przetrenowania drzewa, ale pozwolą nam porównać nowe wyniki z poprzednio otrzymanymi.
# Biblioteka rpart
library(rpart)
# Wyuczenie drzewa na danych
tree
<- rpart(class
~ y
+ x, data,
minsplit = 1,
minbucket = 1)
n= 50
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 50 20 1 (0.60000000 0.40000000)
2) x< 0.9428313 27 2 1 (0.92592593 0.07407407)
4) y< 0.6573468 18 0 1 (1.00000000 0.00000000) *
5) y>=0.6573468 9 2 1 (0.77777778 0.22222222)
10) y>=1.535119 6 0 1 (1.00000000 0.00000000) *
11) y< 1.535119 3 1 2 (0.33333333 0.66666667)
22) y>=0.8255293 2 1 1 (0.50000000 0.50000000)
44) y< 1.233252 1 0 1 (1.00000000 0.00000000) *
45) y>=1.233252 1 0 2 (0.00000000 1.00000000) *
23) y< 0.8255293 1 0 2 (0.00000000 1.00000000) *
3) x>=0.9428313 23 5 2 (0.21739130 0.78260870)
6) y< 0.765695 6 2 1 (0.66666667 0.33333333)
12) x< 3.061114 5 1 1 (0.80000000 0.20000000)
24) y>=-1.348321 3 0 1 (1.00000000 0.00000000) *
25) y< -1.348321 2 1 1 (0.50000000 0.50000000)
50) y< -1.927858 1 0 1 (1.00000000 0.00000000) *
51) y>=-1.927858 1 0 2 (0.00000000 1.00000000) *
13) x>=3.061114 1 0 2 (0.00000000 1.00000000) *
7) y>=0.765695 17 1 2 (0.05882353 0.94117647)
14) x< 2.227567 6 1 2 (0.16666667 0.83333333)
28) x>=1.930539 1 0 1 (1.00000000 0.00000000) *
29) x< 1.930539 5 0 2 (0.00000000 1.00000000) *
15) x>=2.227567 11 0 2 (0.00000000 1.00000000) *
Bibioteka rpart umożliwia także stworzenie graficznej reprezentacji drzewa za pomocą przeciążonej funkcji plot, do której należy dodać odpowiednie opcje związane z tekstem.
# Rysowanie drzewa
plot(tree, uniform = TRUE)
text(tree, use.n = TRUE, all = TRUE, font = 2)
Rysunek 5.7
Trzeba jednak uczciwie przyznać, że podstawowy rysunek jest, delikatnie mówiąc, daleki od doskonałości. Z pomocą przychodzi biblioteka rpart.plot z dedykowaną funkcją rpart.plot(), umożliwiającą otrzymanie duuużo lepszego efektu...
# Biblioteka rpart.plot
library(rpart.plot)
# Rysowanie drzewa
rpart.plot(tree, type = 1, extra = 1)
Rysunek 5.8
DRZEWA: PRZEWIDYWANIE
Rzecz jasna, drzewa służą do przewidywania, tzn klasyfikacji nowych danych. Jeśli wykorzystamy nasze drzewo do powtórnej klasyfikacji, otrzymamy pełną zgodność.
# Przewidywanie
tree.class <- predict(tree, newdata = data, type = "class")
# Porównanie z oryginalnymi klasami
table(data$class, tree.class)
Jednak taki klasyfikator jest niewątpliwie "przeuczony". Wywołamy teraz funkcję rpart() dla domyślnych parametrów
# Domyślne wywołanie rpart
# minsplit = 20, minbucket = round(minsplit/3)
tree.def <- rpart(class ~ y + x, data)
rpart.plot(tree.def, type = 1, extra = 1)
Rysunek 5.9
i dokonamy przewidywania za jej pomocą
# Porównanie z oryginalnymi klasami
tree.class <- predict(tree.def, newdata = data, type = "class")
table(data$class, tree.class)
Jak widać, efekty nie są dużo gorsze. Sytuacja może się jeszcze bardziej zmienić w przypadku wykorzystania nowych danych.
# Generacja nowych danych
data.new <- draw.data.gauss(S1, S2, m1, m2, n1, n2)
# Przewidywanie za pomocą "przeuczonego" klasyfikatora
tree.class <- predict(tree, newdata = data.new, type = "class")
table(data.new$class, tree.class)
# Przewidywanie za pomocą "przyciętego" klasyfikatora
tree.class <- predict(tree.def, newdata = data.new, type = "class")
table(data.new$class, tree.class)