Statystyczna Eksploracja Danych

LABORATORIUM 5

DRZEWA: PODZIAŁ NA ARGUMENTACH

Dzisiejsze zajęcia dotyczą drzew klasyfikujących, omówionych podczas Wykładu 4

Najistotniejszym elemntem konstrukcji drzew jest reguła podziału, która decyduje o tym, w jaki sposób poszczególne elementy skupione w danym węźle zostaną przesunięte do swoich węzłów-dzieci. Ze względu na swoja prostotę obliczeniową będziemy korzystać z miary różnorodności danej wskaźnikiem Gini'ego. Przedstawia się ona w następujący sposób, przy czym \(p\) ułamkiem obserwacji należących do klasy 1

\(Q_{G} = 2 p (1 - p) \)

Rozpoczynamy od wylosowania danych z rozkładów Gaussa

library(MASS)

# Generowanie
draw.data.gauss <- function(S1, S2, m1, m2, n1, n2) {

X1 <- mvrnorm(n1, m1, S1)
X2 <- mvrnorm(n2, m2, S2)

X1 <- data.frame(X1); colnames(X1) <- c("x", "y")
X2 <- data.frame(X2); colnames(X2) <- c("x", "y")

X1$class <- 1; X2$class <- 2

data <- rbind(X1, X2); data$class <- factor(data$class)

return(data)
}

# Rysowanie punktów
plot.data <- function(data) {

cols <- c("blue", "orange")

plot(data[,1:2], col = cols[data$class], cex = 2)
text(data[,1:2], labels = 1:nrow(data), cex = 0.6)

}

# Parametry danych z rozkładu Gaussa
S1 <- matrix(c(4, 2, 2, 4), 2, 2)
S2 <- matrix(c(4, 2, 2, 2), 2, 2)

m1 <- c(-1, -1)
m2 <- c(2, 2)

n1 <- 30
n2 <- 20

# Ustawienie ziarna dla losowania danych
set.seed(128)

# Generowanie obserwacji
data <- draw.data.gauss(S1, S2, m1, m2, n1, n2)

# Rysowanie danych
plot.data(data)

przedstawiających się w następujący sposób

Rysunek 5.1

Teraz należy stworzyć cały "silnik" drzewa, czyli zdefiniowac współczynnik Gini'ego oraz funkcję do wyznaczania różnorodności, a dokładniej różnicy różnorodności pomiędzy rodzicem a dziećmi.

# Funkcja do współczynnika Gini'ego
gini <- function(tab) 2 * tab[1] * tab[2]

# Normalizacja prostego histogramu
norm.tab <- function(tab, tab.s) if(tab.s) { tab / tab.s } else { tab }

# Funkcja do wyznaczania różnorodności
get.Q <- function(data, name, threshold) {

# Rozkład klas w oryginalnej probie
tab.all <- table(data$class)
tab.all.s <- sum(tab.all)

# Rozkład klas dla warunku tj. po "lewej" stronie przedzialu
tab.left <- table(data$class[data[,name] <= threshold])
tab.left.s <- sum(tab.left)

# Rozkład klas po "prawej" stronie
tab.right <- tab.all - tab.left
tab.right.s <- tab.all.s - tab.left.s

# Normalizacja rozkładów
tab.all <- norm.tab(tab.all, tab.all.s)
tab.left <- norm.tab(tab.left, tab.left.s)
tab.right <- norm.tab(tab.right, tab.right.s)

# Wyznaczanie współczynników Gini'ego
# czyli różnorodności w węzłach
Q.all <- gini(tab.all)
Q.left <- gini(tab.left)
Q.right <- gini(tab.right)

# Ułamki elementów w węzłach
p.left <- tab.left.s / tab.all.s
p.right <- 1 - p.left

# Całkowita różnorodność w węzłach-dzieciach
Q.children <- p.left * Q.left + p.right * Q.right

# Zwracamy różnicę różnorodności
return(Q.all - Q.children)
}

W tym momencie nie przedstawia już trudności wyznaczenie różnicy różnorodności dla każdej wartości podziału zarówno argumentu \(x\) jak i \(y\) danych.

# Tworzenie wektorów podziałów
threshold.x <- sort(data$x)
threshold.y <- sort(data$y)

# Obliczanie różnic różnorodności
# dla podziałów na argumencie x oraz y
Q.x <- sapply(threshold.x, function(t) get.Q(data, "x", t))
Q.y <- sapply(threshold.y, function(t) get.Q(data, "y", t))

Wartości te można wykreślić i porównać ze sobą.

# Wartość różnorodności w funkcji progu podziału
plot(threshold.x, Q.x, t = "o", pch = 19, xlab = "prog podzialu (wartosc X lub Y)", ylab = "wartosc Q")
points(threshold.y, Q.y, t = "o", pch = 19, col = "blue")

Rysunek 5.2

Jak widać z wykresu, maksymalna wartość dla argumentu \(x\) jest większa niż dla \(y\), tak więc ona zostanie wybrana jako dzieląca zbiór danych, przy czym biorąc pod uwagę iż de facto mówimy o całym przedziale, wygodnie jest przyjąć wartość średnią z dwóch sąsiednich punktów

threshold.x.max <- (threshold.x[which.max(Q.x)] + threshold.x[which.max(Q.x) + 1]) / 2

abline(v = threshold.x.max, lty = 2, col = "red")

Rysunek 5.3

Można wreszcie zaznaczyć ostateczny podział na oryginalnym wykresie danych.

plot.data(data)
rect(-10, -10, threshold.x.max, 10, col = rgb(0, 0, 1, 0.2), border = NA)

Rysunek 5.4

DRZEWA: PODZIAŁ DANYCH

Korzystając z powyższych funkcji gini(), norm.tab() oraz get.Q(), możemy teraz zaproponować bardzo prostą funkcję dostosowaną do przypadku dwóch argumentów, która wyszuka najlepszy podział na argumentach oraz korzystając z tego rozdzieli dane.

split.data <- function(data) {

threshold.x <- sort(data$x)
threshold.y <- sort(data$y)

Q.x <- sapply(threshold.x, function(t) get.Q(data, "x", t))
Q.y <- sapply(threshold.y, function(t) get.Q(data, "y", t))

if(max(Q.x) > max(Q.y)) {

threshold.x.max <- (threshold.x[which.max(Q.x)] + threshold.x[which.max(Q.x) + 1]) / 2

data.left <- data[data$x <= threshold.x.max,]
data.right <- data[data$x > threshold.x.max,]

return(list(var = "x", var.val = threshold.x.max, data.left = data.left, data.right = data.right))
} else {

threshold.y.max <- (threshold.y[which.max(Q.y)] + threshold.y[which.max(Q.y) + 1]) / 2

data.left <- data[data$y <= threshold.y.max,]
data.right <- data[data$y > threshold.y.max,]

return(list(var = "y", var.val = threshold.y.max, data.left = data.left, data.right = data.right))
}
}

Daje to np. możliwość wizualizacji następnego podziału

blue.op <- rgb(0, 0, 1, 0.2)
orange.op <- rgb(1, 0.66, 0, 0.2)

plot.data(data)

s0 <- split.data(data)
rect(-10, -10, s0$var.val, 10, col = blue.op)

s1.r <- split.data(s0$data.right)
rect(s0$var.val, s1.r$var.val, 10, 10, col = orange.op)

Rysunek 5.5

oraz kolejnych

s2.r <- split.data(s1.r$data.right)
rect(s2.r$var.val, s1.r$var.val, 10, 10, col = orange.op) rect(s0$var.val, s1.r$var.val, s2.r$var.val, 10, col = orange.op)

s2.l <- split.data(s1.r$data.left)
rect(s2.l$var.val, -10, 10, s1.r$var.val, col = orange.op)
rect(s0$var.val, -10, s2.l$var.val, s1.r$var.val, col = blue.op)

Rysunek 5.6

Warto przy tym zauważyć, że podziały w rogach prawym górnym oraz lewym górnym są ostateczne: w prostokątach pozostały jedynie elementy jednej klasy.

DRZEWA: PEŁNA FUNKCJA

Powyższe rozważania dają nam możliwość wykonania prymitywnej funkcji do implementacji drzewa klasyfikującego w przypadku dwóch argumentów. Funkcja działa rekurencyjnie i ma na celu dopowoadzenie do takiej sytuacji, aby w liściach znajdowały się jedynie elementy jednej klasy

# Rekurencyjna prymitywna funkcja implementująca drzewa
split.data <- function(data, sp) {

tab <- table(data$class)

if(gini(tab) < 1e-6) {

cat(sp, "Lisc drzewa, klasa A: ", tab[1], "klasa B:", tab[2], "\n")
} else {

threshold.x <- sort(data$x)
threshold.y <- sort(data$y)

Q.x <- sapply(threshold.x, function(t) get.Q(data, "x", t))
Q.y <- sapply(threshold.y, function(t) get.Q(data, "y", t))

if(max(Q.x) > max(Q.y)) {

threshold.x.max <- (threshold.x[which.max(Q.x)] + threshold.x[which.max(Q.x) + 1]) / 2

data.left <- data[data$x <= threshold.x.max,]
data.right <- data[data$x > threshold.x.max,]

cat(sp, "Podzial na X", threshold.x.max, "\n")
} else {

threshold.y.max <- (threshold.y[which.max(Q.y)] + threshold.y[which.max(Q.y) + 1]) / 2

data.left <- data[data$y <= threshold.y.max,]
data.right <- data[data$y > threshold.y.max,]

cat(sp, "Podzial na Y", threshold.y.max, "\n")
}

sp <- paste(sp," ")

split.data(data.left, sp)
split.data(data.right, sp)
}
}

dającej poniższy efekt

split.data(data, "")
 Podzial na X 0.9428313 
   Podzial na Y 0.6573468 
     Lisc drzewa, klasa A:  18 klasa B: 0 
     Podzial na Y 1.535119 
       Podzial na Y 1.233252 
         Podzial na Y 0.8255293 
           Lisc drzewa, klasa A:  0 klasa B: 1 
           Lisc drzewa, klasa A:  1 klasa B: 0 
         Lisc drzewa, klasa A:  0 klasa B: 1 
       Lisc drzewa, klasa A:  6 klasa B: 0 
   Podzial na Y 0.765695 
     Podzial na X 3.061114 
       Podzial na Y -1.348321 
         Podzial na Y -1.927858 
           Lisc drzewa, klasa A:  1 klasa B: 0 
           Lisc drzewa, klasa A:  0 klasa B: 1 
         Lisc drzewa, klasa A:  3 klasa B: 0 
       Lisc drzewa, klasa A:  0 klasa B: 1 
     Podzial na X 2.227567 
       Podzial na X 1.930539 
         Lisc drzewa, klasa A:  0 klasa B: 5 
         Lisc drzewa, klasa A:  1 klasa B: 0 
       Lisc drzewa, klasa A:  0 klasa B: 11 

PAKIET RPART

Oczywiście, celem powyższych rozważań i bardzo prostych funkcji jest jedynie ilustracja problemu. Do wykorzystywania drzew klasyfikujących w pakiecie R jest dedykowana biblioteka rpart, której główną funkcją jest imienniczka pakietu rpart(). Poniżej wywołamy ją z argumentami minsplit=1 oraz minbucket=1. Pierwszy z parametrów określa minimalną liczbę elementów, które muszą się znaleźć w weźle, aby dochodziło do podziału. Drugi - minimalną liczbę obserwacji, która musi znaleźć się w węźle, aby został uznany za liść. Wybrane przez nas parametry są bardzo "agresywne", czyli de facto będą raczej prowadzić do przetrenowania drzewa, ale pozwolą nam porównać nowe wyniki z poprzednio otrzymanymi.

# Biblioteka rpart
library(rpart)

# Wyuczenie drzewa na danych
tree <- rpart(class ~ y + x, data, minsplit = 1, minbucket = 1)
n= 50 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 50 20 1 (0.60000000 0.40000000)  
   2) x< 0.9428313 27  2 1 (0.92592593 0.07407407)  
     4) y< 0.6573468 18  0 1 (1.00000000 0.00000000) *
     5) y>=0.6573468 9  2 1 (0.77777778 0.22222222)  
      10) y>=1.535119 6  0 1 (1.00000000 0.00000000) *
      11) y< 1.535119 3  1 2 (0.33333333 0.66666667)  
        22) y>=0.8255293 2  1 1 (0.50000000 0.50000000)  
          44) y< 1.233252 1  0 1 (1.00000000 0.00000000) *
          45) y>=1.233252 1  0 2 (0.00000000 1.00000000) *
        23) y< 0.8255293 1  0 2 (0.00000000 1.00000000) *
   3) x>=0.9428313 23  5 2 (0.21739130 0.78260870)  
     6) y< 0.765695 6  2 1 (0.66666667 0.33333333)  
      12) x< 3.061114 5  1 1 (0.80000000 0.20000000)  
        24) y>=-1.348321 3  0 1 (1.00000000 0.00000000) *
        25) y< -1.348321 2  1 1 (0.50000000 0.50000000)  
          50) y< -1.927858 1  0 1 (1.00000000 0.00000000) *
          51) y>=-1.927858 1  0 2 (0.00000000 1.00000000) *
      13) x>=3.061114 1  0 2 (0.00000000 1.00000000) *
     7) y>=0.765695 17  1 2 (0.05882353 0.94117647)  
      14) x< 2.227567 6  1 2 (0.16666667 0.83333333)  
        28) x>=1.930539 1  0 1 (1.00000000 0.00000000) *
        29) x< 1.930539 5  0 2 (0.00000000 1.00000000) *
      15) x>=2.227567 11  0 2 (0.00000000 1.00000000) *

Bibioteka rpart umożliwia także stworzenie graficznej reprezentacji drzewa za pomocą przeciążonej funkcji plot, do której należy dodać odpowiednie opcje związane z tekstem.

# Rysowanie drzewa
plot(tree, uniform = TRUE)
text(tree, use.n = TRUE, all = TRUE, font = 2)

Rysunek 5.7

Trzeba jednak uczciwie przyznać, że podstawowy rysunek jest, delikatnie mówiąc, daleki od doskonałości. Z pomocą przychodzi biblioteka rpart.plot z dedykowaną funkcją rpart.plot(), umożliwiającą otrzymanie duuużo lepszego efektu...

# Biblioteka rpart.plot
library(rpart.plot)

# Rysowanie drzewa
rpart.plot(tree, type = 1, extra = 1)

Rysunek 5.8

DRZEWA: PRZEWIDYWANIE

Rzecz jasna, drzewa służą do przewidywania, tzn klasyfikacji nowych danych. Jeśli wykorzystamy nasze drzewo do powtórnej klasyfikacji, otrzymamy pełną zgodność.

# Przewidywanie
tree.class <- predict(tree, newdata = data, type = "class")

# Porównanie z oryginalnymi klasami
table(data$class, tree.class)

Jednak taki klasyfikator jest niewątpliwie "przeuczony". Wywołamy teraz funkcję rpart() dla domyślnych parametrów

# Domyślne wywołanie rpart
# minsplit = 20, minbucket = round(minsplit/3)
tree.def <- rpart(class ~ y + x, data)
rpart.plot(tree.def, type = 1, extra = 1)

Rysunek 5.9

i dokonamy przewidywania za jej pomocą

# Porównanie z oryginalnymi klasami
tree.class <- predict(tree.def, newdata = data, type = "class")
table(data$class, tree.class)

Jak widać, efekty nie są dużo gorsze. Sytuacja może się jeszcze bardziej zmienić w przypadku wykorzystania nowych danych.

# Generacja nowych danych
data.new <- draw.data.gauss(S1, S2, m1, m2, n1, n2)

# Przewidywanie za pomocą "przeuczonego" klasyfikatora
tree.class <- predict(tree, newdata = data.new, type = "class")
table(data.new$class, tree.class)

# Przewidywanie za pomocą "przyciętego" klasyfikatora
tree.class <- predict(tree.def, newdata = data.new, type = "class")
table(data.new$class, tree.class)