티스토리 뷰

Sklearn

Sklearn

헷갈리는 module 정리.

Sklearn version : v0.20.2 기준 작성.

1. sklearn.model_selection.KFold

텍스트?

K 개의 subsample들 (fold) 로 나누고 index를 반환해준다.

사용 예시를 보자.

from sklearn.model_selection import KFold
import numpy as np

X = np.arange(16).reshape((8,-1))
y = np.arange(8).reshape((-1,1))

kf = KFold(n_splits=4)

for train_index, test_index in kf.split(X):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]

output은 다음과 같다.

TRAIN: [2 3 4 5 6 7] TEST: [0 1]
TRAIN: [0 1 4 5 6 7] TEST: [2 3]
TRAIN: [0 1 2 3 6 7] TEST: [4 5]
TRAIN: [0 1 2 3 4 5] TEST: [6 7]

shuffle 하면 index를 섞는다..

참고 : sklearn KFold API

2. sklearn.model_selection.StratifiedKFold

stratifiedlabel 의 분포를 유지시켜준다고 생각하면 된다. 즉, 각 fold 안의 데이터셋의 label 분포가 전체 데이터셋의 label 분포를 따른다.

다시 말해서, 각 fold가 전체 데이터셋을 잘 대표한다.

모델을 학습시킬 때 편향되지 않게 학습시킬 수 있다.

텍스트

예제를 보자...

X = np.arange(12*2).reshape((12,-1))
y = np.array([0,0,1,2,1,0,0,0,0,1,2,2])

skf = StratifiedKFold(n_splits=3)

for train_index, test_index in skf.split(X, y):
    print("TRAIN:", train_index, "TEST:", test_index)

output은 다음과 같다.

TRAIN: [ 4  5  6  7  8  9 10 11] TEST: [ 0  1  2  3]
TRAIN: [ 0  1  2  3  7  8  9 11] TEST: [ 4  5  6 10]
TRAIN: [ 0  1  2  3  4  5  6 10] TEST: [ 7  8  9 11]

먼저 y가 [0,0,1,2,1,0,0,0,0,1,2,2] 이다.

숫자를 어떤 label 이라 생각하고 분포를 살펴보자.

0 1 2
6 3 3

2:1:1 의 분포를 가지고 있다.

Fold 들도 각각 label 0, label 1, label 2를 2:1:1 로 가져야 될 것이다.

텍스트

각 test의 index를 따라가면 세 번의 split 결과 모두 test set의 label 분포가 같음을 확인할 수 있다. KFOLD로 쪼갰으면 분포가 달라졌을 것이다.

for train_index, test_index in skf.split(X, y) 에서 y를 꼭 데이터의 output, 즉, class 로 설정 안해도 된다.

여기서 label을 output의 label이라 생각하지말고, 좀더 일반적인 feature라고 생각하자.

예를 들면, 어떤 특성 (feature) f 가 전체 데이터를 잘 대표할 수도 있고, 이를 통해 데이터를 split하고 싶다면skf.split(X, y) 에서 y 대신 X["f"] 를 입력하자.

참고 : sklearn StratifiedKFold API

댓글
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2025/01   »
1 2 3 4
5 6 7 8 9 10 11
12 13 14 15 16 17 18
19 20 21 22 23 24 25
26 27 28 29 30 31
글 보관함