(1) K-최근접 이웃 분류
지도학습 알고리즘에는 K-최근접이웃(KNN), 선형모델, 나이브 베이즈 분류기, 결정 트리 및 앙상블, 커널 서포트 벡터 머신, 신경망 등이 있다. 그중에서 오늘은 KNN알고리즘에 대해 알아보겠다.
K-최근접이웃은 새로 들어온 데이터에 가장 가까운 k개의 데이터를 찾아서, 새로 들어온 데이터의 값을 결정하는 것이다.
from preamble import *
mglearn.plots.plot_knn_classification(n_neighbors=1)
다음은 데이터셋에 대한 1-최근접 이웃 모델을 예측한 것이다. 이제 k를
n_neighbors=1
부분을 변화하면서 출력이 어떻게 바뀌는지 관찰해보겠다.
mglearn.plots.plot_knn_classification(n_neighbors=3)
이것은 3-최근접 이웃모델 예측으로, n_neighbors=1
일때와 n_neighbors=3
일때의 차이를 시각적으로 확인할 수 있다. 또한, 이웃을 하나만 사용했을 때와 셋을 사용했을 때의 예측이 달라진 것도 확인할 수 있다.
이 그림은 이진 분류이지만 다중 분류에서도 같은 방법을 적용할 수 있다. 또한, 참고로 k는 홀수를 자주 쓴다.
이제 scikit-learn을 사용해서 k-최근접 이웃 알고리즘을 어떻게 적용하는지 알아보겠다.
from sklearn.model_selection import train_test_split #test,validation, test로 나누는 함수
x,y=mglearn.datasets.make_forge()
print(x.shape)
print(y.shape)
x_train,x_test,y_train,y_test=train_test_split(x,y,random_state=0) #random하게 섞어서 나누기
print(x_train.shape)
print(x_test.shape)
print(y_train.shape)
print(y_test.shape)
out
(26, 2)
(26,)
(19, 2)
(7, 2)
(19,)
(7,)
일반화 성능을 평가할 수 있도록 데이터를 훈련 세트와 테스트 세트로 나누는 과정이다. 이제, KNeighborsClassifier를 import하고 객체를 만든다. 이 과정에서 매개변수를 지정해아 하는데, 3으로 지정하고 진행하겠다.
from sklearn.neighbors import KNeighborsClassifier
clf=KNeighborsClassifier(n_neighbors=3)
clf.fit(x_train,y_train) #fit =학습하라는 의미
out
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
metric_params=None, n_jobs=None, n_neighbors=3, p=2,
weights='uniform')
훈련 세트를 이용해서 분류 모델을 학습시킨다. 이제 predict를 호출해서 예측을 진행한다. 테스트 세트의 각 데이터에 대해 훈련 세트에서 가장 가까운 이웃을 계산한 다음 가장 많은 클래스를 찾는다.
y_pred=clf.predict(x_test)
print("테스트 세트 예측:",y_pred)
print("테스트 세트 참값:",y_test)
out
테스트 세트 예측: [1 0 1 0 1 0 0]
테스트 세트 참값: [1 0 1 0 1 1 0]
모델이 일반화된 정도를 알기 위해 score를 이용해준다.
print("테스트 세트 정확도: {:.2f}".format(clf.score(x_test,y_test)))
out
테스트 세트 정확도: 0.86
K-NN은 학습이 필요없다. K-NN자체가 트레이닝 데이터가 있는 상태에서 새로운 데이터가 들어오면 트레이닝 데이터 자체를 사용해서 근접 이웃을 찾는 것이기 때문이다. 성능을 찾기 위해서는 테스트 데이터셋을 이용하는데, 예측값과 정답이 같은지를 비교하며 성능을 측정한다.
(2)KNeighborsClassifier 분석
이차원 데이터셋일 경우 xy평면에 테스트 포인트 예측을 그려볼 수 있다.
fig, axes=plt.subplots(1,3,figsize=(10,3)) #subplot은 그림을 여러개 그릴때 사용한다.
for n_neighbors, ax in zip([1,3,9],axes):
# fit 메소드는 self 오브젝트를 리턴합니다
# 그래서 객체 생성과 fit 메소드를 한 줄에 쓸 수 있습니다.
clf=KNeighborsClassifier(n_neighbors=n_neighbors).fit(x,y) #아까위에 있던 구조를 한줄로 쓴 것
mglearn.plots.plot_2d_separator(clf,x,fill=True,eps=0.5,ax=ax,alpha=.4)
mglearn.discrete_scatter(x[:,0],x[:,1],y,ax=ax)
ax.set_title("{} neighbor".format(n_neighbors))
ax.set_xlabel("feature0")
ax.set_ylabel("feature1")
axes[0].legend(loc='best')
그림을 보면 빨강색 부분과 파랑색 부분의 경계를 관찰할 수 있다. 이를 결정 경계decision boundary라고 한다. 우리는 위의 그림을 통해 k값이 클수록 경계가 완만해지는 것을 알 수 있다. 반대로 k가 작을수록 구불구불해짐을 알 수 있다. 정확도는 k=1일때가 가장 높지만 이것을 train data에 대한 정확도이므로 좋지 않다. k=9일때는 모델이 너무 복잡해지므로 3정도가 가장 적당하다.
이제 유방암 데이터셋을 사용해서 K-NN을 실습해보겠다.
from sklearn.datasets import load_breast_cancer
cancer=load_breast_cancer()
x_train,x_test,y_train,y_test=train_test_split(cancer.data,cancer.target,stratify=cancer.target,random_state=66)
training_accuracy=[]
test_accuracy=[]
#1에서 10까지 n_neighbors를 적용
neighbors_settings=range(1,11) #k를 1부터 10까지
for n_neighbors in neighbors_settings:
#모델 생성
clf=KNeighborsClassifier(n_neighbors=n_neighbors)
clf.fit(x_train,y_train)
#훈련 세트 정확도 저장
training_accuracy.append(clf.score(x_train,y_train))
#일반화 정확도 저장
test_accuracy.append(clf.score(x_test,y_test))
plt.plot(neighbors_settings,training_accuracy,label="train accuracy")
plt.plot(neighbors_settings,test_accuracy,label="test accuracy")
plt.ylabel("accuracy")
plt.xlabel("n_neighbors")
plt.legend()
이 그림은 이웃 수에 따른 훈련 세트와 테스트 세트의 정확도를 보여준다. 이 그래프를 통해 과대적합과 과소적합의 특징을 볼 수 있다. 이웃의 수가 하나일 때는, 훈련 데이터에 대한 예측이 높다. 그러나 테스트 세트의 정확도는 낮다. 이는 모델을 너무 복잡하게 만들었기 때문이다. 그렇다면 이웃의 수가 많아지면 어떨까? 이웃의 수가 10개일 경우에는 모델이 너무 단순해지기 때문에 정확도가 더 떨어진다. 그러므로, 정확도는 6개 정도일 때가 가장 적당하다.
(3) k-Neighbors Regression
k-최근접 이웃 알고리즘은 회귀 분석에도 사용한다. 그럼 관련된 예제를 실습해보겠다.
mglearn.plots.plot_knn_regression(n_neighbors=1)
참고로 그림그리는 예제에서 빨간색 오류가 엄청 뜨는 경우가 있는데, 이는 관련 폰트가 없어서 뜨는 오류로 무시해도 괜찮다. 다만 위의 그림처럼 글자가 아니라 네모칸이 뜰 것이다. 상관은 없지만, 이해를 위해 타블로그에서 사진을 가져와서 올리도록 하겠다.
사진은 https://tensorflow.blog/%ed%8c%8c%ec%9d%b4%ec%8d%ac-%eb%a8%b8%ec%8b%a0%eb%9f%ac%eb%8b%9d/2-3-2-k-%ec%b5%9c%ea%b7%bc%ec%a0%91-%ec%9d%b4%ec%9b%83/ 이곳에서 가져왔다.
다음은 이웃이 3일 때의 경우를 알아보겠다.
mglearn.plots.plot_knn_regression(n_neighbors=3)
녹색 별과 가까운 점 3개의 feature 값을 보고 그 3개의 y값의 평균을 채택하는 방식이다. 이제 scikit-learn을 이용해서 실습을 할 것인데, 방법은 위에서 했던 KNeighborsClassifier와 비슷하다.
from sklearn.neighbors import KNeighborsRegressor
x,y=mglearn.datasets.make_wave(n_samples=40)
#wave데이터셋을 훈련 세트와 테스트 세트로 나눕니다
x_train,x_test,y_train,y_test=train_test_split(x,y,random_state=0)
print(x_train.shape)
print(x_test.shape)
#이웃의 수를 3으로 하여 모델의 객체를 만듭니다
reg=KNeighborsRegressor(n_neighbors=3)
#훈련 데이터와 타깃을 사용하여 모델을 학습시킵니다.
reg.fit(x_train,y_train)
out
(30, 1)
(10, 1)
KNeighborsRegressor(algorithm='auto', leaf_size=30, metric='minkowski',
metric_params=None, n_jobs=None, n_neighbors=3, p=2,
weights='uniform')
테스트 세트 예측을 시작해본다.
print("테스트 세트 예측:\n",reg.predict(x_test))
print("테스트 세트 R^2: {:.2f}".format(reg.score(x_test,y_test)))
out
테스트 세트 예측:
[-0.054 0.357 1.137 -1.894 -1.139 -1.631 0.357 0.912 -0.447 -1.139]
테스트 세트 R^2: 0.83
(4) KNeighborsRegressor 분석
위에서 했던 것처럼 그림으로 표현해보겠다.
fig, axes=plt.subplots(1,3,figsize=(15,4))
#-3과 3 사이에 1000개의 데이터 포인트를 만듭니다.
line=np.linspace(-3,3,1000).reshape(-1,1)
for n_neighbors, ax in zip([1,3,9],axes):
#1,3,9 이웃을 사용한 예측을 합니다
reg=KNeighborsRegressor(n_neighbors=n_neighbors)
reg.fit(x_train,y_train)
ax.plot(line,reg.predict(line))
ax.plot(x_train,y_train,'^',c=mglearn.cm2(0),markersize=8)
ax.plot(x_test,y_test, 'v',c=mglearn.cm2(1),markersize=8)
ax.set_title(
"{}-neighbor:{:.2f} test score:{:.2f}".format(
n_neighbors,reg.score(x_train,y_train),reg.score(x_test,y_test)))
ax.set_xlabel("feature")
ax.set_ylabel("target")
axes[0].legend(["model predictions","train data/target","test data/target"],loc="best")
이웃에 따라 최근버 이웃 회귀로 만들어진 예측을 비교하는 것이다. 그림을 보면 알 수 있듯이 k가 높을면 높아질수록 부드러워짐을 알 수 있다. 또한, k가 작으면 각 훈련 세트가 예측에 주는 영향이 크기 때문에 불안정한 예측이 나타나고, 이웃이 많아지면 훈련 데이터에는 잘 안맞을 수 있지만 더 안정된 예측을 한다.
지금까지의 실습에서 알 수 있는 점은 K-NN은 이해하기 쉬운 모델이라는 것이다. 또한, 많은 매개변수를 조정하지 않아도 좋은 성능을 발휘한다. 그러나 훈련 세트가 크면 예측이 느려진다는 점과 데이터 전처리 과정이 필요한 것 등의 단점 때문에 현업에서는 잘 쓰지 않는다.