티스토리 뷰
사이킷런 (scikit learn) 에서의 교차검증 (cross validation), Kfold 정리
수학수학 2018. 12. 22. 02:48Sklearn
헷갈리는 module 정리.
Sklearn version : v0.20.2 기준 작성.
1. sklearn.model_selection.KFold
K 개의 subsample들 (fold) 로 나누고 index를 반환해준다.
사용 예시를 보자.
from sklearn.model_selection import KFold
import numpy as np
X = np.arange(16).reshape((8,-1))
y = np.arange(8).reshape((-1,1))
kf = KFold(n_splits=4)
for train_index, test_index in kf.split(X):
print("TRAIN:", train_index, "TEST:", test_index)
X_train, X_test = X[train_index], X[test_index]
y_train, y_test = y[train_index], y[test_index]
output은 다음과 같다.
TRAIN: [2 3 4 5 6 7] TEST: [0 1]
TRAIN: [0 1 4 5 6 7] TEST: [2 3]
TRAIN: [0 1 2 3 6 7] TEST: [4 5]
TRAIN: [0 1 2 3 4 5] TEST: [6 7]
shuffle 하면 index를 섞는다..
참고 : sklearn KFold API
2. sklearn.model_selection.StratifiedKFold
stratified 는 label 의 분포를 유지시켜준다고 생각하면 된다. 즉, 각 fold 안의 데이터셋의 label 분포가 전체 데이터셋의 label 분포를 따른다.
다시 말해서, 각 fold가 전체 데이터셋을 잘 대표한다.
모델을 학습시킬 때 편향되지 않게 학습시킬 수 있다.
예제를 보자...
X = np.arange(12*2).reshape((12,-1))
y = np.array([0,0,1,2,1,0,0,0,0,1,2,2])
skf = StratifiedKFold(n_splits=3)
for train_index, test_index in skf.split(X, y):
print("TRAIN:", train_index, "TEST:", test_index)
output은 다음과 같다.
TRAIN: [ 4 5 6 7 8 9 10 11] TEST: [ 0 1 2 3]
TRAIN: [ 0 1 2 3 7 8 9 11] TEST: [ 4 5 6 10]
TRAIN: [ 0 1 2 3 4 5 6 10] TEST: [ 7 8 9 11]
먼저 y가 [0,0,1,2,1,0,0,0,0,1,2,2] 이다.
숫자를 어떤 label 이라 생각하고 분포를 살펴보자.
0 | 1 | 2 |
---|---|---|
6 | 3 | 3 |
2:1:1 의 분포를 가지고 있다.
Fold 들도 각각 label 0, label 1, label 2를 2:1:1 로 가져야 될 것이다.
for train_index, test_index in skf.split(X, y) 에서 y를 꼭 데이터의 output, 즉, class 로 설정 안해도 된다.
여기서 label을 output의 label이라 생각하지말고, 좀더 일반적인 feature라고 생각하자.
예를 들면, 어떤 특성 (feature) f 가 전체 데이터를 잘 대표할 수도 있고, 이를 통해 데이터를 split하고 싶다면skf.split(X, y) 에서 y 대신 X["f"] 를 입력하자.
- Total
- Today
- Yesterday
- 사이킷런 KFold
- variable
- pytorch
- Visual Studio Code에서 R
- 사이킷런
- vs code
- Pytorch .data
- Pytorch Variable
- sklearn.model_selection.KFold
- 박사과정 #PhD
- r
- Bit vector
- scikit learn
- 비트 벡터
- 교차검증
- 비쥬얼스튜디오코드
- sublime text
- vscode
- cross validation
- 파이토치
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |