Model Optimization

모델 최적화(하이퍼 파라미터 튜닝)

오버피팅

오버피팅이란 머신러닝에서 모델이 훈련 데이터에 지나치게 학습(fitting)되어 있어, 데이터에 대해 일반화 하는 능력이 떨어지는 현상을 말한다. 조금 더 풀어 설명하자면 모델이 훈련 데이터의 패턴 뿐 아니라 노이즈까지 학습해버려 실제 세계의 복잡성과 변동성을 반영하는데 실패하는 상태이다.

오버피팅의 발생 원인

  • 모델의 복잡도가 너무 높은 경우 모델의 복잡도는 모델이 가질 수 있는 학습 능력의 정도를 의미한다. 높은 복잡도의 모델은 많은 수의 파라미터를 가지는 동시에 더 복잡한 패턴을 학습할 수 있다. 다만, 훈련 데이터의 미묘한 패턴과 잡음까지 학습할 위험이 있다. 결론적으로, 복잡도가 너무 높은 모델은 훈련 데이터에는 높은 정확도를 보이지만, 새로운 데이터를 일반화하지 못하는 경향을 보인다.

  • 훈련 데이터의 양이 부족한 경우 훈련 데이터의 양이 충분하지 않다면 모델은 제한적 데이터에서 패턴을 학습하게 된다. 모델은 전체 데이터 분포를 대표할 수 없는 소수의 데이터 포인트에 과도하게 의존하여 학습하며 새로운 데이터에 대한 성능 저하를 유발한다.

  • 데이터에 잡음이 많거나 훈련 데이터의 특정 패턴을 지나치게 학습하는 경우

교차 검증(Cross-Validation)

검증 데이터는 머신러닝 모델을 개발할 때 모델의 성능을 평가하고 하이퍼 파라미터를 조정하는데 사용되는 데이터 세트를 말한다. 일반적으로 전체 데이터 세트를 훈련(Training) / 검증(Validation) / 테스트 (Test) 데이터로 나눈 후 사용한다. | 데이터 | 설명 | | ——- | ——————————————————————————————– | | 훈련 데이터 | 모델을 훈련시키기 위해 사용되는 데이터 세트이다. 모델은 훈련 데이터를 기반으로 패턴을 학습하고 파라미터를 조정한다. | | 검증 데이터 | 훈련 과정 중 모델의 성능을 평가하고, 하이퍼 파라미터를 조정하기 위해 사용되는 데이터 세트이다. 검증 데이터의 사용으로 훈련 데이터에 오버피팅 되는 것을 방지한다. | | 테스트 데이터 | 모델의 일반화 능력을 최종적으로 평가하기 위해 사용되는 데이터 세트이다. 테스트 데이터는 훈련 과정에서 모델에 전혀 노출되지 않으며, 모델의 일반화 성능을 측정한다. |

K-Fold Cross-Validation

K-Fold 교차 검증은 데이터를 K개의 동일한 크기의 부분집합(Fold)로 나눈 후 한 번씩 돌아가면서 검증데이터로 사용한다. K-Fold 교차 검증은 다음과 같은 과정을 통해 이루어진다.

  1. 테스트 데이터 분할 : 전체 데이터 셋에서 일정 비율을 테스트 데이터로 분리
  2. 폴드 분리 : 남은 데이터 세트를 K개의 동일한 크기를 갖는 폴드로 나눈다.
  3. 모델 훈련 및 검증 : 다음 과정을 K번 반복한다.
    1. 폴드 중 하나를 검증 데이터 세트로 선택
    2. 남은 폴드들을 결합하여 훈련 데이터 세트를 형성
    3. 모델을 훈련 데이터 세트로 피팅
    4. 피팅된 모델을 검증 데이터 세트에 적용하여 성능 평가
효과 설명
데이터 활용 극대화 테스트를 제외한 전체 데이터 셋을 훈련과 검증에 모두 사용하여 제한된 양의 데이터를 최대한 활용할 수 있다.
모델의 일반화 능력 평가 다양한 훈련과 검증 세트 조합을 사용해 모델을 평가함으로 모델의 일반화 능력에 대한 보다 신뢰할 수 있는 추정치를 얻을 수 있다.
오버피팅 감소 특정 훈련 세트에 오버피팅되는 것을 방지할 수 있다.
하이퍼 파라미터 튜닝의 용이성 다양한 하이퍼 파라미터 설정에 대해 K-Fold 교차검증을 반복적으로 수행함으로써, 최적의 모델 구성을 찾는 데 도움이 된다.



Enjoy Reading This Article?

Here are some more articles you might like to read next:

  • [CS231n]Exercise1.5 - Features
  • [CS231n]Exercise1.4 - Two Layer Net
  • [CS231n]Exercise1.3 - Softmax
  • [CS231n]Exercise1.2 - Support Vector Machine