scikit-learn 에서 제공되는 decision tree 를 학습하면 각 branching 과정에 대한 정보가 모델에 저장됩니다. 이를 이용하면 tree traversal visualization 을 하거나, parameters 를 저장하여 직접 decsion rules based classifier 를 만들 수 있습니다. 이번 포스트에서는 학습된 decision tree 의 parameters 를 이용하는 방법을 소개합니다.
Brief review of decision tree
의사결정나무는 데이터의 공간을 직사각형으로 나눠가며 최대한 같은 종류의 데이터로 이뤄진 부분공간을 찾아가는 classifiers 입니다. 마치 clustering 처럼 비슷한 공간을 하나의 leaf node 로 나눠갑니다. 아래의 데이터를 분류하는 decision tree 를 학습한다고 가정합니다.
Decision tree 는 매번 각 변수에서 적절한 기준선을 찾아가며 공간을 이분 (bisect)합니다. 이 과정을 조건이 만족할 때까지 반복합니다.
위 그림처럼 tree 가 학습된 뒤, 우리는 각 공간에 대한 bisection path 를 얻을 수 있습니다. #0 공간인 root 는 #1 과 #2 로 나눠지고, #2 는 #3 과 #4 로 나눠집니다.
위 그림처럼 학습된 tree 의 각 decision rules 와 (parent, children) 의 관계는 아래와 같습니다.
section no | # red | # blue | # entropy | # decision | # left child | # right child |
---|---|---|---|---|---|---|
0 | 9 | 12 | 0.297 | 1 | 2 | |
1 | 3 | 6 | 0.276 | 5 | 6 | |
2 | 6 | 6 | 0.301 | 3 | 4 | |
3 | 6 | 2 | 0.244 | 7 | 8 | |
4 | 0 | 4 | 0 | - | - | - |
5 | 3 | 0 | 0 | - | - | - |
6 | 0 | 6 | 0 | - | - | - |
7 | 4 | 0 | 0 | - | - | - |
8 | 2 | 2 | 0.301 | 9 | 10 | |
9 | 0 | 2 | 0 | - | - | - |
10 | 2 | 0 | 0 | - | - | - |
이 정보를 이용하면 decision tree rules 를 시각화 하거나, 학습된 tree 의 parameters 를 다른 모델에 이식할 수 있습니다. 이번 포스트에서는 위 표의 정보들을 이용하여 text 로 된 tree traversal visualization 을 수행합니다.
Dataset
우리는 이전 포스트에서 만든 인공데이터를 이용하여 decision tree 가 학습한 tree path 의 parameters 를 이용하여 tree traversal 을 하겠습니다.
우리가 이용할 데이터는 아래처럼 만들 수 있습니다. 데이터 생성 함수는 이전 포스트를 참고하세요
from soydata.visualize import ipython_2d_scatter
from soydata.data import get_decision_tree_data_2
X_2, y_2 = get_decision_tree_data_2(n_samples=2000)
ipython_2d_scatter(X_2, y_2, marker_size=5, height=1000, width=1000, title='Synthetic dataset 2')
이를 학습하면 아래와 학습 과정을 얻을 수 있습니다. Decision tree 는 한 마디에서 하나의 변수만을 이용하기 때문에 사선의 경계면을 계단 형식의 경계선으로 학습합니다. Depth = 5 까지는 사각형의 모습을 하는데 이용되며, 그 이후 depth = 10 까지는 사선 방향의 경계면을 학습하는데 이용됩니다.
Parameters
우리의 목표는 brief review 에서의 표를 그릴 수 있는 정보를 얻는 것입니다.
각 마디의 children 구조는 tree_.children_left, tree_.children_right, tree_.threshold, tree_.feature 에 저장되어 있습니다. 마디의 개수는 총 59 개 입니다.
left = dt.tree_.children_left
right = dt.tree_.children_right
threshold = dt.tree_.threshold
features = dt.tree_.feature
print(left.shape) # (59,)
print(right.shape) # (59,)
print(threshold.shape) # (59,)
print(features.shape) # (59,)
각 마디의 idx 는 만들어진 순서대로입니다. Root node 는 left, right, threshold 등의 0 번째 입니다. left[0] = 1, right[0] = 6 은 root node 의 left child 은 1 번, right child 은 6 번 마디라는 의미입니다.
left 와 right children 을 살펴보면 - 값이 있습니다. 해당 마디가 leaf node 일 때 negative index 를 지닙니다.
print(left)
# array([ 1, 2, -1, 4, -1, -1, 7, -1, 9, 10, 11, 12, 13, -1, -1, -1, -1,
# 18, 19, 20, -1, 22, 23, -1, -1, -1, 27, 28, -1, 30, -1, -1, -1, 34,
# 35, 36, 37, -1, -1, -1, -1, 42, 43, 44, -1, -1, -1, 48, 49, -1, -1,
# 52, -1, -1, 55, -1, 57, -1, -1])
print(right)
# array([ 6, 3, -1, 5, -1, -1, 8, -1, 54, 17, 16, 15, 14, -1, -1, -1, -1,
# 33, 26, 21, -1, 25, 24, -1, -1, -1, 32, 29, -1, 31, -1, -1, -1, 41,
# 40, 39, 38, -1, -1, -1, -1, 47, 46, 45, -1, -1, -1, 51, 50, -1, -1,
# 53, -1, -1, 56, -1, 58, -1, -1])
비슷하게 features 도 negative value 를 지닙니다. 우리가 이용한 synthetic data 는 두 개의 변수로 이뤄져 있습니다. features[0] = 1 은 첫번째 decision 에 중 을 이용하였다는 의미입니다. index=2 는 leaf node 이기 때문에 decision 을 하지 않습니다. 그렇기 때문에 features[2] 에 negative index -2 가 저장되어 있습니다.
threshold 는 각 마디에서 이용된 threshold 입니다. features[0] 와 threshold[0] 를 합쳐 해석하면 is ? 입니다. Threshold 는 negative value 를 가질 수 있기 때문에, leaf nodes 를 확인하기 위해서는 children 이나 features 의 indices 를 살펴봐야 합니다.
print(features)
# array([ 1, 0, -2, 1, -2, -2, 0, -2, 1, 0, 1, 0, 1, -2, -2, -2, -2,
# 1, 0, 0, -2, 1, 0, -2, -2, -2, 0, 0, -2, 1, -2, -2, -2, 0,
# 1, 0, 1, -2, -2, -2, -2, 1, 1, 0, -2, -2, -2, 0, 0, -2, -2,
# 0, -2, -2, 0, -2, 1, -2, -2])
print(threshold)
# array([ 4.00539017, 7.00503731, -2. , 3.49398017, -2. ,
# -2. , 2.00174618, -2. , 6.98615742, 4.36656189,
# 4.98681259, 3.27961016, 4.26769447, -2. , -2. ,
# -2. , -2. , 4.8539381 , 7.30683231, 6.99763584,
# -2. , 4.68205166, 7.00949955, -2. , -2. ,
# -2. , 7.4809618 , 7.38458014, -2. , 4.61641979,
# -2. , -2. , -2. , 5.66478252, 5.6269908 ,
# 4.47982836, 5.11054134, -2. , -2. , -2. ,
# -2. , 6.48954964, 5.00303268, 7.36141872, -2. ,
# -2. , -2. , 7.06216717, 6.63848925, -2. ,
# -2. , 7.27605534, -2. , -2. , 8.00331211,
# -2. , 8.01910591, -2. , -2. ])
때로는 features 에 이름이 있기도 합니다. Visualization 을 위해서는 names 로 features 를 살펴봐도 좋습니다.
features = ['x{}'.format(feature) if feature >= 0 else None for feature in dt.tree_.feature]
print(features)
# ['x1', 'x0', None, 'x1', None, None, 'x0', None, 'x1', 'x0',
# 'x1', 'x0', 'x1', None, None, None, None, 'x1', 'x0', 'x0',
# None, 'x1', 'x0', None, None, None, 'x0', 'x0', None, 'x1',
# None, None, None, 'x0', 'x1', 'x0', 'x1', None, None, None,
# None, 'x1', 'x1', 'x0', None, None, None, 'x0', 'x0', None,
# None, 'x0', None, None, 'x0', None, 'x1', None, None]
각 마디에 포함되는 samples 의 개수는 value 에 저장되어 있습니다. 59 개 마디에 대하여 (1, n_classes) 의 array 형식입니다. value[0] 은 class 0 이 1045 개, class 1 이 961 개 포함되었다는 의미입니다. index=0 은 root node 이기 때문에 데이터셋 전체에 포함된 class 0, 1 의 개수가 각각 1045, 961 개 라는 의미입니다.
print(dt.tree_.value.shape)
# (59, 1, 2)
print(dt.tree_.value[0])
# array([[1045., 961.]])
이를 이용하여 각 마디의 label 과 각 class 에 속할 확률을 만들 수 있습니다.
prob = np.asarray([freq/freq.sum() for freq in dt.tree_.value])
print(prob[0])
# array([[0.52093719, 0.47906281]])
labels = np.asarray([prob_.argmax() for prob_ in prob])
print(labels)
# array([0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0,
# 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0,
# 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1])
Tree traversal
이제 우리가 이용할 parameters 의 구조와 위치는 모두 파악했습니다. 이를 이용하여 tree traversal 을 수행합니다. node 가 leaf 인지 확인하는 함수를 만듭니다.
def is_leaf(i):
return features[i] is None
각 마디마다 children 이 있을 경우, right, left 순서로 children 에 관한 정보 (idx, depth, equation) 를 list 로 만듭니다. 이를 stack 에 쌓은 뒤 pop 을 할 것이기 때문에 right, left 순서로 만듭니다. Equation 은 와 같은 decision rule 입니다.
def make_stack_item(idx, depth):
# (child idx, depth, equation)
items = [
(right[idx], depth, '{} > {}'.format(features[idx], '%.3f'%threshold[idx])),
(left[idx], depth, '{} < {}'.format(features[idx], '%.3f'%threshold[idx]))
]
return items
Print 를 할 때에는 depth 만큼의 indention 을 넣습니다. 각 마디의 label 과 데이터의 개수, 각 클래스에 속할 확률을 함께 출력햡니다.
def print_status(i, depth, equation):
message = '{} ({}). label={} n_samples={}, prob=({})'.format(
'|--- ' * depth, # indention
equation, # equation
labels[i], # label
size[i], # n samples
', '.join(['%.3f' % float(p) for p in prob[i][0]])) # prob
print(message, flush=True)
Traversal 은 처음 root node 에 대하여 items 을 만든 뒤, stack 에 다른 마디가 없을 때까지 while loop 을 반복합니다.
Root node 의 children 을 이용하여 만든 두 개의 items 가 stack 에 포함되어 있습니다.
# initialize
stack = make_stack_item(idx=0, depth=1)
첫 번째 item 을 pop() 하면 left child 의 idx, depth, equation 이 return 됩니다.
while stack:
idx, depth, equation = stack.pop()
이번 마디가 leaf node 이면 마디의 상태를 출력하고, children 을 지닌 branch 이면 자신의 상태를 출력한 뒤, 자신의 children 을 stack 에 쌓습니다. 이를 통하여 depth-first search 를 할 수 있습니다.
# if node is leaf print status
if is_leaf(idx):
print_status(idx, depth, equation)
# else print status and add children (left, right) order
else:
print_status(idx, depth, equation)
stack += make_stack_item(idx, depth+1)
이 과정을 정리하면 아래와 같습니다.
def _print_tree_traversal(left, right, features, threshold, labels, size, prob):
# initialize
stack = make_stack_item(idx=0, depth=1)
# print root
print('Root n_samples={}, prob=({})'.format(
size[0], ', '.join(['%.3f' % float(p) for p in prob[0][0]])))
# while stack is not empty
while stack:
idx, depth, equation = stack.pop()
# if node is leaf print status
if is_leaf(idx):
print_status(idx, depth, equation)
# else print status and add children (left, right) order
else:
print_status(idx, depth, equation)
stack += make_stack_item(idx, depth+1)
_print_tree_traversal() 함수는 decision tree 의 parameters 를 각각 입력받는 함수입니다. 학습된 decision tree 를 입력하면 이용가능한 형태의 parameters 를 만드는 함수를 하나 더 만들어줍니다. Features names 를 입력하면 이를 이용하고, 그렇지 않다면 처럼 변수 이름을 붙여줍니다.
def print_tree_traversal(dt, feature_names=None):
left = dt.tree_.children_left
right = dt.tree_.children_right
threshold = dt.tree_.threshold
if feature_names:
features = [feature_names[f] if f >= 0 else None for f in dt.tree_.feature]
else:
features = ['x{}'.format(f) if f >= 0 else None for f in dt.tree_.feature]
size = np.asarray([freq.sum() for freq in dt.tree_.value], dtype=np.int)
prob = np.asarray([freq/freq.sum() for freq in dt.tree_.value])
labels = np.asarray([prob_.argmax() for prob_ in prob])
_print_tree_traversal(left, right, features, threshold, labels, size, prob)
print_tree_traversal() 를 실행한 결과입니다.
print_tree_traversal(dt)
위 과정에 언급된 parameters 만을 저장하면 scikit-learn 을 이용하여 학습된 decision tree classifier 를 다른 언어로 구현할 수 있습니다.
Root n_samples=2006, prob=(0.521, 0.479)
|--- (x1 < 4.005). label=1 n_samples=824, prob=(0.220, 0.780)
|--- |--- (x0 < 7.005). label=1 n_samples=612, prob=(0.000, 1.000)
|--- |--- (x0 > 7.005). label=0 n_samples=212, prob=(0.854, 0.146)
|--- |--- |--- (x1 < 3.494). label=0 n_samples=181, prob=(1.000, 0.000)
|--- |--- |--- (x1 > 3.494). label=1 n_samples=31, prob=(0.000, 1.000)
|--- (x1 > 4.005). label=0 n_samples=1182, prob=(0.731, 0.269)
|--- |--- (x0 < 2.002). label=0 n_samples=229, prob=(1.000, 0.000)
|--- |--- (x0 > 2.002). label=0 n_samples=953, prob=(0.666, 0.334)
|--- |--- |--- (x1 < 6.986). label=0 n_samples=474, prob=(0.511, 0.489)
|--- |--- |--- |--- (x0 < 4.367). label=1 n_samples=135, prob=(0.141, 0.859)
|--- |--- |--- |--- |--- (x1 < 4.987). label=0 n_samples=31, prob=(0.613, 0.387)
|--- |--- |--- |--- |--- |--- (x0 < 3.280). label=1 n_samples=17, prob=(0.294, 0.706)
|--- |--- |--- |--- |--- |--- |--- (x1 < 4.268). label=0 n_samples=5, prob=(1.000, 0.000)
|--- |--- |--- |--- |--- |--- |--- (x1 > 4.268). label=1 n_samples=12, prob=(0.000, 1.000)
|--- |--- |--- |--- |--- |--- (x0 > 3.280). label=0 n_samples=14, prob=(1.000, 0.000)
|--- |--- |--- |--- |--- (x1 > 4.987). label=1 n_samples=104, prob=(0.000, 1.000)
|--- |--- |--- |--- (x0 > 4.367). label=0 n_samples=339, prob=(0.658, 0.342)
|--- |--- |--- |--- |--- (x1 < 4.854). label=1 n_samples=108, prob=(0.463, 0.537)
|--- |--- |--- |--- |--- |--- (x0 < 7.307). label=0 n_samples=50, prob=(0.960, 0.040)
|--- |--- |--- |--- |--- |--- |--- (x0 < 6.998). label=0 n_samples=39, prob=(1.000, 0.000)
|--- |--- |--- |--- |--- |--- |--- (x0 > 6.998). label=0 n_samples=11, prob=(0.818, 0.182)
|--- |--- |--- |--- |--- |--- |--- |--- (x1 < 4.682). label=0 n_samples=10, prob=(0.900, 0.100)
|--- |--- |--- |--- |--- |--- |--- |--- |--- (x0 < 7.009). label=1 n_samples=1, prob=(0.000, 1.000)
|--- |--- |--- |--- |--- |--- |--- |--- |--- (x0 > 7.009). label=0 n_samples=9, prob=(1.000, 0.000)
|--- |--- |--- |--- |--- |--- |--- |--- (x1 > 4.682). label=1 n_samples=1, prob=(0.000, 1.000)
|--- |--- |--- |--- |--- |--- (x0 > 7.307). label=1 n_samples=58, prob=(0.034, 0.966)
|--- |--- |--- |--- |--- |--- |--- (x0 < 7.481). label=1 n_samples=7, prob=(0.286, 0.714)
|--- |--- |--- |--- |--- |--- |--- |--- (x0 < 7.385). label=1 n_samples=4, prob=(0.000, 1.000)
|--- |--- |--- |--- |--- |--- |--- |--- (x0 > 7.385). label=0 n_samples=3, prob=(0.667, 0.333)
|--- |--- |--- |--- |--- |--- |--- |--- |--- (x1 < 4.616). label=0 n_samples=2, prob=(1.000, 0.000)
|--- |--- |--- |--- |--- |--- |--- |--- |--- (x1 > 4.616). label=1 n_samples=1, prob=(0.000, 1.000)
|--- |--- |--- |--- |--- |--- |--- (x0 > 7.481). label=1 n_samples=51, prob=(0.000, 1.000)
|--- |--- |--- |--- |--- (x1 > 4.854). label=0 n_samples=231, prob=(0.749, 0.251)
|--- |--- |--- |--- |--- |--- (x0 < 5.665). label=1 n_samples=57, prob=(0.333, 0.667)
|--- |--- |--- |--- |--- |--- |--- (x1 < 5.627). label=0 n_samples=20, prob=(0.950, 0.050)
|--- |--- |--- |--- |--- |--- |--- |--- (x0 < 4.480). label=0 n_samples=2, prob=(0.500, 0.500)
|--- |--- |--- |--- |--- |--- |--- |--- |--- (x1 < 5.111). label=0 n_samples=1, prob=(1.000, 0.000)
|--- |--- |--- |--- |--- |--- |--- |--- |--- (x1 > 5.111). label=1 n_samples=1, prob=(0.000, 1.000)
|--- |--- |--- |--- |--- |--- |--- |--- (x0 > 4.480). label=0 n_samples=18, prob=(1.000, 0.000)
|--- |--- |--- |--- |--- |--- |--- (x1 > 5.627). label=1 n_samples=37, prob=(0.000, 1.000)
|--- |--- |--- |--- |--- |--- (x0 > 5.665). label=0 n_samples=174, prob=(0.885, 0.115)
|--- |--- |--- |--- |--- |--- |--- (x1 < 6.490). label=0 n_samples=129, prob=(0.953, 0.047)
|--- |--- |--- |--- |--- |--- |--- |--- (x1 < 5.003). label=0 n_samples=15, prob=(0.600, 0.400)
|--- |--- |--- |--- |--- |--- |--- |--- |--- (x0 < 7.361). label=0 n_samples=9, prob=(1.000, 0.000)
|--- |--- |--- |--- |--- |--- |--- |--- |--- (x0 > 7.361). label=1 n_samples=6, prob=(0.000, 1.000)
|--- |--- |--- |--- |--- |--- |--- |--- (x1 > 5.003). label=0 n_samples=114, prob=(1.000, 0.000)
|--- |--- |--- |--- |--- |--- |--- (x1 > 6.490). label=0 n_samples=45, prob=(0.689, 0.311)
|--- |--- |--- |--- |--- |--- |--- |--- (x0 < 7.062). label=1 n_samples=14, prob=(0.071, 0.929)
|--- |--- |--- |--- |--- |--- |--- |--- |--- (x0 < 6.638). label=1 n_samples=11, prob=(0.000, 1.000)
|--- |--- |--- |--- |--- |--- |--- |--- |--- (x0 > 6.638). label=1 n_samples=3, prob=(0.333, 0.667)
|--- |--- |--- |--- |--- |--- |--- |--- (x0 > 7.062). label=0 n_samples=31, prob=(0.968, 0.032)
|--- |--- |--- |--- |--- |--- |--- |--- |--- (x0 < 7.276). label=0 n_samples=3, prob=(0.667, 0.333)
|--- |--- |--- |--- |--- |--- |--- |--- |--- (x0 > 7.276). label=0 n_samples=28, prob=(1.000, 0.000)
|--- |--- |--- (x1 > 6.986). label=0 n_samples=479, prob=(0.820, 0.180)
|--- |--- |--- |--- (x0 < 8.003). label=0 n_samples=349, prob=(1.000, 0.000)
|--- |--- |--- |--- (x0 > 8.003). label=1 n_samples=130, prob=(0.338, 0.662)
|--- |--- |--- |--- |--- (x1 < 8.019). label=0 n_samples=44, prob=(1.000, 0.000)
|--- |--- |--- |--- |--- (x1 > 8.019). label=1 n_samples=86, prob=(0.000, 1.000)
위 결과에서 살펴볼 수 있듯이 처음 를 통하여 아래 부분은 큰 어려움없이 분류가 잘 됩니다. 대부분의 depth 를 삼각형 모양의 경계면을 학습하는데 이용하고 있음을 확인할 수 있습니다.