scikit-learn 에서 제공되는 decision tree 를 학습하면 각 branching 과정에 대한 정보가 모델에 저장됩니다. 이를 이용하면 tree traversal visualization 을 하거나, parameters 를 저장하여 직접 decsion rules based classifier 를 만들 수 있습니다. 이번 포스트에서는 학습된 decision tree 의 parameters 를 이용하는 방법을 소개합니다.
Brief review of decision tree
의사결정나무는 데이터의 공간을 직사각형으로 나눠가며 최대한 같은 종류의 데이터로 이뤄진 부분공간을 찾아가는 classifiers 입니다. 마치 clustering 처럼 비슷한 공간을 하나의 leaf node 로 나눠갑니다. 아래의 데이터를 분류하는 decision tree 를 학습한다고 가정합니다.
Decision tree 는 매번 각 변수에서 적절한 기준선을 찾아가며 공간을 이분 (bisect)합니다. 이 과정을 조건이 만족할 때까지 반복합니다.
위 그림처럼 tree 가 학습된 뒤, 우리는 각 공간에 대한 bisection path 를 얻을 수 있습니다. #0 공간인 root 는 #1 과 #2 로 나눠지고, #2 는 #3 과 #4 로 나눠집니다.
위 그림처럼 학습된 tree 의 각 decision rules 와 (parent, children) 의 관계는 아래와 같습니다.
section no | # red | # blue | # entropy | # decision | # left child | # right child |
---|---|---|---|---|---|---|
0 | 9 | 12 | 0.297 | 1 | 2 | |
1 | 3 | 6 | 0.276 | 5 | 6 | |
2 | 6 | 6 | 0.301 | 3 | 4 | |
3 | 6 | 2 | 0.244 | 7 | 8 | |
4 | 0 | 4 | 0 | - | - | - |
5 | 3 | 0 | 0 | - | - | - |
6 | 0 | 6 | 0 | - | - | - |
7 | 4 | 0 | 0 | - | - | - |
8 | 2 | 2 | 0.301 | 9 | 10 | |
9 | 0 | 2 | 0 | - | - | - |
10 | 2 | 0 | 0 | - | - | - |
이 정보를 이용하면 decision tree rules 를 시각화 하거나, 학습된 tree 의 parameters 를 다른 모델에 이식할 수 있습니다. 이번 포스트에서는 위 표의 정보들을 이용하여 text 로 된 tree traversal visualization 을 수행합니다.
Dataset
우리는 이전 포스트에서 만든 인공데이터를 이용하여 decision tree 가 학습한 tree path 의 parameters 를 이용하여 tree traversal 을 하겠습니다.
우리가 이용할 데이터는 아래처럼 만들 수 있습니다. 데이터 생성 함수는 이전 포스트를 참고하세요
이를 학습하면 아래와 학습 과정을 얻을 수 있습니다. Decision tree 는 한 마디에서 하나의 변수만을 이용하기 때문에 사선의 경계면을 계단 형식의 경계선으로 학습합니다. Depth = 5 까지는 사각형의 모습을 하는데 이용되며, 그 이후 depth = 10 까지는 사선 방향의 경계면을 학습하는데 이용됩니다.
Parameters
우리의 목표는 brief review 에서의 표를 그릴 수 있는 정보를 얻는 것입니다.
각 마디의 children 구조는 tree_.children_left, tree_.children_right, tree_.threshold, tree_.feature 에 저장되어 있습니다. 마디의 개수는 총 59 개 입니다.
각 마디의 idx 는 만들어진 순서대로입니다. Root node 는 left, right, threshold 등의 0 번째 입니다. left[0] = 1, right[0] = 6 은 root node 의 left child 은 1 번, right child 은 6 번 마디라는 의미입니다.
left 와 right children 을 살펴보면 - 값이 있습니다. 해당 마디가 leaf node 일 때 negative index 를 지닙니다.
비슷하게 features 도 negative value 를 지닙니다. 우리가 이용한 synthetic data 는 두 개의 변수로 이뤄져 있습니다. features[0] = 1 은 첫번째 decision 에 중 을 이용하였다는 의미입니다. index=2 는 leaf node 이기 때문에 decision 을 하지 않습니다. 그렇기 때문에 features[2] 에 negative index -2 가 저장되어 있습니다.
threshold 는 각 마디에서 이용된 threshold 입니다. features[0] 와 threshold[0] 를 합쳐 해석하면 is ? 입니다. Threshold 는 negative value 를 가질 수 있기 때문에, leaf nodes 를 확인하기 위해서는 children 이나 features 의 indices 를 살펴봐야 합니다.
때로는 features 에 이름이 있기도 합니다. Visualization 을 위해서는 names 로 features 를 살펴봐도 좋습니다.
각 마디에 포함되는 samples 의 개수는 value 에 저장되어 있습니다. 59 개 마디에 대하여 (1, n_classes) 의 array 형식입니다. value[0] 은 class 0 이 1045 개, class 1 이 961 개 포함되었다는 의미입니다. index=0 은 root node 이기 때문에 데이터셋 전체에 포함된 class 0, 1 의 개수가 각각 1045, 961 개 라는 의미입니다.
이를 이용하여 각 마디의 label 과 각 class 에 속할 확률을 만들 수 있습니다.
Tree traversal
이제 우리가 이용할 parameters 의 구조와 위치는 모두 파악했습니다. 이를 이용하여 tree traversal 을 수행합니다. node 가 leaf 인지 확인하는 함수를 만듭니다.
각 마디마다 children 이 있을 경우, right, left 순서로 children 에 관한 정보 (idx, depth, equation) 를 list 로 만듭니다. 이를 stack 에 쌓은 뒤 pop 을 할 것이기 때문에 right, left 순서로 만듭니다. Equation 은 와 같은 decision rule 입니다.
Print 를 할 때에는 depth 만큼의 indention 을 넣습니다. 각 마디의 label 과 데이터의 개수, 각 클래스에 속할 확률을 함께 출력햡니다.
Traversal 은 처음 root node 에 대하여 items 을 만든 뒤, stack 에 다른 마디가 없을 때까지 while loop 을 반복합니다.
Root node 의 children 을 이용하여 만든 두 개의 items 가 stack 에 포함되어 있습니다.
첫 번째 item 을 pop() 하면 left child 의 idx, depth, equation 이 return 됩니다.
이번 마디가 leaf node 이면 마디의 상태를 출력하고, children 을 지닌 branch 이면 자신의 상태를 출력한 뒤, 자신의 children 을 stack 에 쌓습니다. 이를 통하여 depth-first search 를 할 수 있습니다.
이 과정을 정리하면 아래와 같습니다.
_print_tree_traversal() 함수는 decision tree 의 parameters 를 각각 입력받는 함수입니다. 학습된 decision tree 를 입력하면 이용가능한 형태의 parameters 를 만드는 함수를 하나 더 만들어줍니다. Features names 를 입력하면 이를 이용하고, 그렇지 않다면 처럼 변수 이름을 붙여줍니다.
print_tree_traversal() 를 실행한 결과입니다.
위 과정에 언급된 parameters 만을 저장하면 scikit-learn 을 이용하여 학습된 decision tree classifier 를 다른 언어로 구현할 수 있습니다.
Root n_samples=2006, prob=(0.521, 0.479)
|--- (x1 < 4.005). label=1 n_samples=824, prob=(0.220, 0.780)
|--- |--- (x0 < 7.005). label=1 n_samples=612, prob=(0.000, 1.000)
|--- |--- (x0 > 7.005). label=0 n_samples=212, prob=(0.854, 0.146)
|--- |--- |--- (x1 < 3.494). label=0 n_samples=181, prob=(1.000, 0.000)
|--- |--- |--- (x1 > 3.494). label=1 n_samples=31, prob=(0.000, 1.000)
|--- (x1 > 4.005). label=0 n_samples=1182, prob=(0.731, 0.269)
|--- |--- (x0 < 2.002). label=0 n_samples=229, prob=(1.000, 0.000)
|--- |--- (x0 > 2.002). label=0 n_samples=953, prob=(0.666, 0.334)
|--- |--- |--- (x1 < 6.986). label=0 n_samples=474, prob=(0.511, 0.489)
|--- |--- |--- |--- (x0 < 4.367). label=1 n_samples=135, prob=(0.141, 0.859)
|--- |--- |--- |--- |--- (x1 < 4.987). label=0 n_samples=31, prob=(0.613, 0.387)
|--- |--- |--- |--- |--- |--- (x0 < 3.280). label=1 n_samples=17, prob=(0.294, 0.706)
|--- |--- |--- |--- |--- |--- |--- (x1 < 4.268). label=0 n_samples=5, prob=(1.000, 0.000)
|--- |--- |--- |--- |--- |--- |--- (x1 > 4.268). label=1 n_samples=12, prob=(0.000, 1.000)
|--- |--- |--- |--- |--- |--- (x0 > 3.280). label=0 n_samples=14, prob=(1.000, 0.000)
|--- |--- |--- |--- |--- (x1 > 4.987). label=1 n_samples=104, prob=(0.000, 1.000)
|--- |--- |--- |--- (x0 > 4.367). label=0 n_samples=339, prob=(0.658, 0.342)
|--- |--- |--- |--- |--- (x1 < 4.854). label=1 n_samples=108, prob=(0.463, 0.537)
|--- |--- |--- |--- |--- |--- (x0 < 7.307). label=0 n_samples=50, prob=(0.960, 0.040)
|--- |--- |--- |--- |--- |--- |--- (x0 < 6.998). label=0 n_samples=39, prob=(1.000, 0.000)
|--- |--- |--- |--- |--- |--- |--- (x0 > 6.998). label=0 n_samples=11, prob=(0.818, 0.182)
|--- |--- |--- |--- |--- |--- |--- |--- (x1 < 4.682). label=0 n_samples=10, prob=(0.900, 0.100)
|--- |--- |--- |--- |--- |--- |--- |--- |--- (x0 < 7.009). label=1 n_samples=1, prob=(0.000, 1.000)
|--- |--- |--- |--- |--- |--- |--- |--- |--- (x0 > 7.009). label=0 n_samples=9, prob=(1.000, 0.000)
|--- |--- |--- |--- |--- |--- |--- |--- (x1 > 4.682). label=1 n_samples=1, prob=(0.000, 1.000)
|--- |--- |--- |--- |--- |--- (x0 > 7.307). label=1 n_samples=58, prob=(0.034, 0.966)
|--- |--- |--- |--- |--- |--- |--- (x0 < 7.481). label=1 n_samples=7, prob=(0.286, 0.714)
|--- |--- |--- |--- |--- |--- |--- |--- (x0 < 7.385). label=1 n_samples=4, prob=(0.000, 1.000)
|--- |--- |--- |--- |--- |--- |--- |--- (x0 > 7.385). label=0 n_samples=3, prob=(0.667, 0.333)
|--- |--- |--- |--- |--- |--- |--- |--- |--- (x1 < 4.616). label=0 n_samples=2, prob=(1.000, 0.000)
|--- |--- |--- |--- |--- |--- |--- |--- |--- (x1 > 4.616). label=1 n_samples=1, prob=(0.000, 1.000)
|--- |--- |--- |--- |--- |--- |--- (x0 > 7.481). label=1 n_samples=51, prob=(0.000, 1.000)
|--- |--- |--- |--- |--- (x1 > 4.854). label=0 n_samples=231, prob=(0.749, 0.251)
|--- |--- |--- |--- |--- |--- (x0 < 5.665). label=1 n_samples=57, prob=(0.333, 0.667)
|--- |--- |--- |--- |--- |--- |--- (x1 < 5.627). label=0 n_samples=20, prob=(0.950, 0.050)
|--- |--- |--- |--- |--- |--- |--- |--- (x0 < 4.480). label=0 n_samples=2, prob=(0.500, 0.500)
|--- |--- |--- |--- |--- |--- |--- |--- |--- (x1 < 5.111). label=0 n_samples=1, prob=(1.000, 0.000)
|--- |--- |--- |--- |--- |--- |--- |--- |--- (x1 > 5.111). label=1 n_samples=1, prob=(0.000, 1.000)
|--- |--- |--- |--- |--- |--- |--- |--- (x0 > 4.480). label=0 n_samples=18, prob=(1.000, 0.000)
|--- |--- |--- |--- |--- |--- |--- (x1 > 5.627). label=1 n_samples=37, prob=(0.000, 1.000)
|--- |--- |--- |--- |--- |--- (x0 > 5.665). label=0 n_samples=174, prob=(0.885, 0.115)
|--- |--- |--- |--- |--- |--- |--- (x1 < 6.490). label=0 n_samples=129, prob=(0.953, 0.047)
|--- |--- |--- |--- |--- |--- |--- |--- (x1 < 5.003). label=0 n_samples=15, prob=(0.600, 0.400)
|--- |--- |--- |--- |--- |--- |--- |--- |--- (x0 < 7.361). label=0 n_samples=9, prob=(1.000, 0.000)
|--- |--- |--- |--- |--- |--- |--- |--- |--- (x0 > 7.361). label=1 n_samples=6, prob=(0.000, 1.000)
|--- |--- |--- |--- |--- |--- |--- |--- (x1 > 5.003). label=0 n_samples=114, prob=(1.000, 0.000)
|--- |--- |--- |--- |--- |--- |--- (x1 > 6.490). label=0 n_samples=45, prob=(0.689, 0.311)
|--- |--- |--- |--- |--- |--- |--- |--- (x0 < 7.062). label=1 n_samples=14, prob=(0.071, 0.929)
|--- |--- |--- |--- |--- |--- |--- |--- |--- (x0 < 6.638). label=1 n_samples=11, prob=(0.000, 1.000)
|--- |--- |--- |--- |--- |--- |--- |--- |--- (x0 > 6.638). label=1 n_samples=3, prob=(0.333, 0.667)
|--- |--- |--- |--- |--- |--- |--- |--- (x0 > 7.062). label=0 n_samples=31, prob=(0.968, 0.032)
|--- |--- |--- |--- |--- |--- |--- |--- |--- (x0 < 7.276). label=0 n_samples=3, prob=(0.667, 0.333)
|--- |--- |--- |--- |--- |--- |--- |--- |--- (x0 > 7.276). label=0 n_samples=28, prob=(1.000, 0.000)
|--- |--- |--- (x1 > 6.986). label=0 n_samples=479, prob=(0.820, 0.180)
|--- |--- |--- |--- (x0 < 8.003). label=0 n_samples=349, prob=(1.000, 0.000)
|--- |--- |--- |--- (x0 > 8.003). label=1 n_samples=130, prob=(0.338, 0.662)
|--- |--- |--- |--- |--- (x1 < 8.019). label=0 n_samples=44, prob=(1.000, 0.000)
|--- |--- |--- |--- |--- (x1 > 8.019). label=1 n_samples=86, prob=(0.000, 1.000)
위 결과에서 살펴볼 수 있듯이 처음 를 통하여 아래 부분은 큰 어려움없이 분류가 잘 됩니다. 대부분의 depth 를 삼각형 모양의 경계면을 학습하는데 이용하고 있음을 확인할 수 있습니다.