Random Forest로 손글씨 분류하기

2024. 3. 12. 11:28·Machine Learning/Decision Tree

 

 

 
 

Random Forest로 손글씨 분류하기¶

 
In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


np.random.seed(2021)
 
 

1. Data¶

 
 

1.1 Data Load¶

손글씨 데이터는 0~9 까지의 숫자를 손으로 쓴 데이터입니다.
데이터는 sklearn.datasets의 load_digits 를 이용해 받을 수 있습니다.

 
In [2]:
from sklearn.datasets import load_digits

digits = load_digits()
 
In [3]:
data, target = digits["data"], digits["target"]
 
 

1.2 Data EDA¶

 
 

데이터는 각 픽셀의 값을 나타냅니다.

 
In [4]:
data[0], target[0]
 
Out[4]:
(array([ 0.,  0.,  5., 13.,  9.,  1.,  0.,  0.,  0.,  0., 13., 15., 10.,
        15.,  5.,  0.,  0.,  3., 15.,  2.,  0., 11.,  8.,  0.,  0.,  4.,
        12.,  0.,  0.,  8.,  8.,  0.,  0.,  5.,  8.,  0.,  0.,  9.,  8.,
         0.,  0.,  4., 11.,  0.,  1., 12.,  7.,  0.,  0.,  2., 14.,  5.,
        10., 12.,  0.,  0.,  0.,  0.,  6., 13., 10.,  0.,  0.,  0.]),
 0)
 
 

데이터의 크기를 확인하면 64인데 이는 8*8 이미지를 flatten 시켰기 때문입니다.

 
In [5]:
data[0].shape
 
Out[5]:
(64,)
 
 

실제로 0부터 9까지의 데이터를 시각화하면 다음과 같이 나타납니다.

 
In [6]:
samples = data[:10].reshape(10, 8, 8)
fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(20, 10))
for idx, sample in enumerate(samples):
    axes[idx//5, idx%5].imshow(sample, cmap="gray")
 
 
 
In [41]:
samples
 
Out[41]:
array([[[ 0.,  0.,  5., 13.,  9.,  1.,  0.,  0.],
        [ 0.,  0., 13., 15., 10., 15.,  5.,  0.],
        [ 0.,  3., 15.,  2.,  0., 11.,  8.,  0.],
        [ 0.,  4., 12.,  0.,  0.,  8.,  8.,  0.],
        [ 0.,  5.,  8.,  0.,  0.,  9.,  8.,  0.],
        [ 0.,  4., 11.,  0.,  1., 12.,  7.,  0.],
        [ 0.,  2., 14.,  5., 10., 12.,  0.,  0.],
        [ 0.,  0.,  6., 13., 10.,  0.,  0.,  0.]],

       [[ 0.,  0.,  0., 12., 13.,  5.,  0.,  0.],
        [ 0.,  0.,  0., 11., 16.,  9.,  0.,  0.],
        [ 0.,  0.,  3., 15., 16.,  6.,  0.,  0.],
        [ 0.,  7., 15., 16., 16.,  2.,  0.,  0.],
        [ 0.,  0.,  1., 16., 16.,  3.,  0.,  0.],
        [ 0.,  0.,  1., 16., 16.,  6.,  0.,  0.],
        [ 0.,  0.,  1., 16., 16.,  6.,  0.,  0.],
        [ 0.,  0.,  0., 11., 16., 10.,  0.,  0.]],

       [[ 0.,  0.,  0.,  4., 15., 12.,  0.,  0.],
        [ 0.,  0.,  3., 16., 15., 14.,  0.,  0.],
        [ 0.,  0.,  8., 13.,  8., 16.,  0.,  0.],
        [ 0.,  0.,  1.,  6., 15., 11.,  0.,  0.],
        [ 0.,  1.,  8., 13., 15.,  1.,  0.,  0.],
        [ 0.,  9., 16., 16.,  5.,  0.,  0.,  0.],
        [ 0.,  3., 13., 16., 16., 11.,  5.,  0.],
        [ 0.,  0.,  0.,  3., 11., 16.,  9.,  0.]],

       [[ 0.,  0.,  7., 15., 13.,  1.,  0.,  0.],
        [ 0.,  8., 13.,  6., 15.,  4.,  0.,  0.],
        [ 0.,  2.,  1., 13., 13.,  0.,  0.,  0.],
        [ 0.,  0.,  2., 15., 11.,  1.,  0.,  0.],
        [ 0.,  0.,  0.,  1., 12., 12.,  1.,  0.],
        [ 0.,  0.,  0.,  0.,  1., 10.,  8.,  0.],
        [ 0.,  0.,  8.,  4.,  5., 14.,  9.,  0.],
        [ 0.,  0.,  7., 13., 13.,  9.,  0.,  0.]],

       [[ 0.,  0.,  0.,  1., 11.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  7.,  8.,  0.,  0.,  0.],
        [ 0.,  0.,  1., 13.,  6.,  2.,  2.,  0.],
        [ 0.,  0.,  7., 15.,  0.,  9.,  8.,  0.],
        [ 0.,  5., 16., 10.,  0., 16.,  6.,  0.],
        [ 0.,  4., 15., 16., 13., 16.,  1.,  0.],
        [ 0.,  0.,  0.,  3., 15., 10.,  0.,  0.],
        [ 0.,  0.,  0.,  2., 16.,  4.,  0.,  0.]],

       [[ 0.,  0., 12., 10.,  0.,  0.,  0.,  0.],
        [ 0.,  0., 14., 16., 16., 14.,  0.,  0.],
        [ 0.,  0., 13., 16., 15., 10.,  1.,  0.],
        [ 0.,  0., 11., 16., 16.,  7.,  0.,  0.],
        [ 0.,  0.,  0.,  4.,  7., 16.,  7.,  0.],
        [ 0.,  0.,  0.,  0.,  4., 16.,  9.,  0.],
        [ 0.,  0.,  5.,  4., 12., 16.,  4.,  0.],
        [ 0.,  0.,  9., 16., 16., 10.,  0.,  0.]],

       [[ 0.,  0.,  0., 12., 13.,  0.,  0.,  0.],
        [ 0.,  0.,  5., 16.,  8.,  0.,  0.,  0.],
        [ 0.,  0., 13., 16.,  3.,  0.,  0.,  0.],
        [ 0.,  0., 14., 13.,  0.,  0.,  0.,  0.],
        [ 0.,  0., 15., 12.,  7.,  2.,  0.,  0.],
        [ 0.,  0., 13., 16., 13., 16.,  3.,  0.],
        [ 0.,  0.,  7., 16., 11., 15.,  8.,  0.],
        [ 0.,  0.,  1.,  9., 15., 11.,  3.,  0.]],

       [[ 0.,  0.,  7.,  8., 13., 16., 15.,  1.],
        [ 0.,  0.,  7.,  7.,  4., 11., 12.,  0.],
        [ 0.,  0.,  0.,  0.,  8., 13.,  1.,  0.],
        [ 0.,  4.,  8.,  8., 15., 15.,  6.,  0.],
        [ 0.,  2., 11., 15., 15.,  4.,  0.,  0.],
        [ 0.,  0.,  0., 16.,  5.,  0.,  0.,  0.],
        [ 0.,  0.,  9., 15.,  1.,  0.,  0.,  0.],
        [ 0.,  0., 13.,  5.,  0.,  0.,  0.,  0.]],

       [[ 0.,  0.,  9., 14.,  8.,  1.,  0.,  0.],
        [ 0.,  0., 12., 14., 14., 12.,  0.,  0.],
        [ 0.,  0.,  9., 10.,  0., 15.,  4.,  0.],
        [ 0.,  0.,  3., 16., 12., 14.,  2.,  0.],
        [ 0.,  0.,  4., 16., 16.,  2.,  0.,  0.],
        [ 0.,  3., 16.,  8., 10., 13.,  2.,  0.],
        [ 0.,  1., 15.,  1.,  3., 16.,  8.,  0.],
        [ 0.,  0., 11., 16., 15., 11.,  1.,  0.]],

       [[ 0.,  0., 11., 12.,  0.,  0.,  0.,  0.],
        [ 0.,  2., 16., 16., 16., 13.,  0.,  0.],
        [ 0.,  3., 16., 12., 10., 14.,  0.,  0.],
        [ 0.,  1., 16.,  1., 12., 15.,  0.,  0.],
        [ 0.,  0., 13., 16.,  9., 15.,  2.,  0.],
        [ 0.,  0.,  0.,  3.,  0.,  9., 11.,  0.],
        [ 0.,  0.,  0.,  0.,  9., 15.,  4.,  0.],
        [ 0.,  0.,  9., 12., 13.,  3.,  0.,  0.]]])
 
 

1.3 Data split¶

 
 

데이터를 Train, Test로 나누겠습니다.

 
In [7]:
from sklearn.model_selection import train_test_split

train_data, test_data, train_target, test_target = train_test_split(
    data, target, train_size=0.7, random_state=2021
)
 
In [8]:
print(f"train_data size: {len(train_target)}, {len(train_target)/len(data):.2f}")
print(f"test_data size: {len(test_target)}, {len(test_target)/len(data):.2f}")
 
 
train_data size: 1257, 0.70
test_data size: 540, 0.30
 
 

2. Random Forest¶

 
In [9]:
from sklearn.ensemble import RandomForestClassifier


random_forest = RandomForestClassifier()
 
 

2.1 학습¶

 
In [10]:
random_forest.fit(train_data, train_target)
 
Out[10]:
RandomForestClassifier()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
RandomForestClassifier()
 
 

2.2 Feature Importance¶

 
In [11]:
random_forest.feature_importances_
 
Out[11]:
array([0.00000000e+00, 2.22027389e-03, 2.01089999e-02, 1.07830476e-02,
       9.51528198e-03, 2.36360479e-02, 9.31093623e-03, 6.69196714e-04,
       8.85114030e-06, 8.82351373e-03, 2.63886754e-02, 6.78867334e-03,
       1.54145479e-02, 2.58694831e-02, 5.88021424e-03, 4.15847662e-04,
       0.00000000e+00, 6.79631151e-03, 2.29701107e-02, 2.65587908e-02,
       3.23883747e-02, 4.61493602e-02, 9.64852734e-03, 2.12386017e-04,
       1.65085512e-05, 1.53382481e-02, 4.21053792e-02, 2.57443328e-02,
       3.07340934e-02, 1.89734748e-02, 3.20272649e-02, 6.33666089e-05,
       0.00000000e+00, 3.04673193e-02, 2.16692479e-02, 2.20796421e-02,
       3.91744952e-02, 1.99743626e-02, 2.36173762e-02, 0.00000000e+00,
       1.17972340e-05, 1.03136921e-02, 3.92388221e-02, 3.67256272e-02,
       2.11985343e-02, 2.03332950e-02, 1.74046206e-02, 8.96174336e-05,
       4.31000081e-05, 2.92475245e-03, 1.62593325e-02, 2.46582747e-02,
       1.44533638e-02, 2.65565338e-02, 2.35906893e-02, 1.84734419e-03,
       7.11059765e-05, 2.59587724e-03, 2.17536511e-02, 1.20310587e-02,
       2.69356351e-02, 2.94907310e-02, 1.61051482e-02, 2.82483197e-03])
 
In [12]:
feature_importance = pd.Series(random_forest.feature_importances_)
 
In [13]:
feature_importance.head(10)
 
Out[13]:
0    0.000000
1    0.002220
2    0.020109
3    0.010783
4    0.009515
5    0.023636
6    0.009311
7    0.000669
8    0.000009
9    0.008824
dtype: float64
 
In [14]:
feature_importance = feature_importance.sort_values(ascending=False)
 
In [15]:
feature_importance.head(10)
 
Out[15]:
21    0.046149
26    0.042105
42    0.039239
36    0.039174
43    0.036726
20    0.032388
30    0.032027
28    0.030734
33    0.030467
61    0.029491
dtype: float64
 
In [16]:
feature_importance.head(10).plot(kind="barh")
 
Out[16]:
<AxesSubplot: >
 
 
In [17]:
image = random_forest.feature_importances_.reshape(8, 8)

plt.imshow(image, cmap=plt.cm.hot, interpolation="nearest")
cbar = plt.colorbar(ticks=[random_forest.feature_importances_.min(), random_forest.feature_importances_.max()])
cbar.ax.set_yticklabels(['Not Important', 'Very Important'])
plt.axis("off")
 
Out[17]:
(-0.5, 7.5, 7.5, -0.5)
 
 
 

2.3 예측¶

 
In [18]:
train_pred = random_forest.predict(train_data)
test_pred = random_forest.predict(test_data)
 
 

실제 데이터를 한 번 그려보겠습니다.

 
In [19]:
plt.imshow(train_data[4].reshape(8, 8), cmap="gray")
 
Out[19]:
<matplotlib.image.AxesImage at 0x1ba2cafb610>
 
 
 

이 데이터에 대한 값을 보면 9로 잘 나오는 것을 볼 수 있습니다.

 
In [20]:
train_pred[4]
 
Out[20]:
9
 
 

2.4 평가¶

 
In [21]:
from sklearn.metrics import accuracy_score

train_acc = accuracy_score(train_target, train_pred)
test_acc = accuracy_score(test_target, test_pred)
 
In [22]:
print(f"train accuracy is {train_acc:.4f}")
print(f"test accuracy is {test_acc:.4f}")
 
 
train accuracy is 1.0000
test accuracy is 0.9667
 
 

3. Best Hyper Parameter¶

 
 

RandomForestClassifier에서 주로 탐색하는 argument들은 다음과 같습니다.

  • n_estimators
    • 몇 개의 나무를 생성할 것 인지 정합니다.
  • criterion
    • 어떤 정보 이득을 기준으로 데이터를 나눌지 정합니다.
    • "gini", "entropy"
  • max_depth
    • 나무의 최대 깊이를 정합니다.
  • min_samples_split
    • 노드가 나눠질 수 있는 최소 데이터 개수를 정합니다.
 
 

탐색해야할 argument들이 많을 때 일일이 지정을 하거나 for loop을 작성하기 힘들어집니다.
이 때 사용할 수 있는 것이 sklearn.model_selection의 GridSearchCV 함수입니다.

 
In [23]:
from sklearn.model_selection import GridSearchCV
 
 

3.1 탐색 범위 선정¶

 
 

탐색할 값들의 argument와 범위를 정합니다.

 
In [24]:
params = {
    "n_estimators": [i for i in range(100, 1000, 200)],
    "max_depth": [i for i in range(10, 50, 10)],
}
 
In [25]:
params
 
Out[25]:
{'n_estimators': [100, 300, 500, 700, 900], 'max_depth': [10, 20, 30, 40]}
 
 

탐색에 사용할 모델을 생성합니다.

 
In [26]:
random_forest = RandomForestClassifier()
 
 

3.2 탐색¶

 
 

탐색을 시작합니다.
cv는 k-fold의 k값입니다.

 
In [27]:
grid = GridSearchCV(estimator=random_forest, param_grid=params, cv=3)
grid = grid.fit(train_data, train_target)
 
 

3.3 결과¶

 
In [28]:
print(f"Best score of paramter search is: {grid.best_score_:.4f}")
 
 
Best score of paramter search is: 0.9730
 
In [29]:
grid.best_params_
 
Out[29]:
{'max_depth': 30, 'n_estimators': 300}
 
In [30]:
print("Best parameter of best score is")
print(f"\t max_depth: {grid.best_params_['max_depth']}")
print(f"\t n_estimators: {grid.best_params_['n_estimators']}")
 
 
Best parameter of best score is
	 max_depth: 30
	 n_estimators: 300
 
In [31]:
best_rf = grid.best_estimator_
 
In [32]:
best_rf
 
Out[32]:
RandomForestClassifier(max_depth=30, n_estimators=300)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
RandomForestClassifier(max_depth=30, n_estimators=300)
 
 

3.4 예측¶

 
In [33]:
train_pred = best_rf.predict(train_data)
test_pred = best_rf.predict(test_data)
 
 

3.5 평가¶

 
In [34]:
best_train_acc = accuracy_score(train_target, train_pred)
best_test_acc = accuracy_score(test_target, test_pred)
 
In [35]:
print(f"Best parameter train accuracy is {best_train_acc:.4f}")
print(f"Best parameter test accuracy is {best_test_acc:.4f}")
 
 
Best parameter train accuracy is 1.0000
Best parameter test accuracy is 0.9704
 
In [36]:
print(f"train accuracy is {train_acc:.4f}")
print(f"test accuracy is {test_acc:.4f}")
 
 
train accuracy is 1.0000
test accuracy is 0.9667
 
 

4. Feature Importance¶

 
In [37]:
best_feature_importance = pd.Series(best_rf.feature_importances_)
 
In [38]:
best_feature_importance = best_feature_importance.sort_values(ascending=False)
 
In [39]:
best_feature_importance.head(10)
 
Out[39]:
21    0.045323
26    0.044439
43    0.043824
36    0.040624
42    0.035054
28    0.033223
20    0.030152
30    0.029488
27    0.028930
60    0.028078
dtype: float64
 
In [40]:
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))
feature_importance.head(10).plot(kind="barh", ax=axes[0], title="Random Forest Feature Importance")
best_feature_importance.head(10).plot(kind="barh", ax=axes[1], title="Best Parameter Feature Importance")
 
Out[40]:
<AxesSubplot: title={'center': 'Best Parameter Feature Importance'}>
 
 
In [ ]:
 

'Machine Learning > Decision Tree' 카테고리의 다른 글

부동산 가격 예측  (0) 2024.03.12
Decision Tree Regressor  (0) 2024.03.12
Iris 꽃 종류 분류  (0) 2024.03.12
Decision Tree Classification 기초  (0) 2024.03.12
'Machine Learning/Decision Tree' 카테고리의 다른 글
  • 부동산 가격 예측
  • Decision Tree Regressor
  • Iris 꽃 종류 분류
  • Decision Tree Classification 기초
Juson
Juson
  • Juson
    Juson의 데이터 공부
    Juson
  • 전체
    오늘
    어제
    • 분류 전체보기 (95)
      • RAG (2)
      • AI (2)
        • NLP (0)
        • Generative Model (0)
        • Deep Reinforcement Learning (2)
        • LLM (0)
      • Logistic Optimization (0)
      • Machine Learning (37)
        • Linear Regression (2)
        • Logistic Regression (2)
        • Decision Tree (5)
        • Naive Bayes (1)
        • KNN (2)
        • SVM (2)
        • Clustering (4)
        • Dimension Reduction (3)
        • Boosting (6)
        • Abnomaly Detection (2)
        • Recommendation (4)
        • Embedding & NLP (4)
      • Reinforcement Learning (5)
      • Deep Learning (10)
        • Deep learning Bacis Mathema.. (10)
      • Optimization (2)
        • OR Optimization (0)
        • Convex Optimization (0)
        • Integer Optimization (0)
      • SNA 분석 (0)
      • 포트폴리오 최적화 공부 (0)
        • 최적화 기법 (0)
        • 금융 베이스 (0)
      • Finanancial engineering (0)
      • 프로그래머스 데브코스(Boot camp) (15)
        • SQL (9)
        • Python (5)
        • Machine Learning (1)
      • Python (22)
      • Project (0)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.4
Juson
Random Forest로 손글씨 분류하기
상단으로

티스토리툴바