티스토리 뷰

FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence

코드: GitHub - google-research/fixmatch: A simple method to perform semi-supervised learning with limited data.

 

google-research/fixmatch

A simple method to perform semi-supervised learning with limited data. - google-research/fixmatch

github.com

 

1. Introduction

딥러닝 모델이 괜찮은 결과를 내기 위해서는 라벨 데이터의 개수가 많아야 한다.

ex. GPT-3의 경우, 175B (=1750억)개의 파라미터를 사용한다. 

 

2. Semi-Supervised Learning

그래서 사용하는 것이 Semi-Supervised Learning 기법이다.

더보기

우선 Supervised Learning 지도학습은 정답이 있는 데이터로 학습을 하여 어떤 데이터를 분류하거나 (분류, Classification) / 값을 예측하는 것 (회귀, Regression) 이다.

Unsupervised Learning 비지도학습은 정답을 따로 알려주지 않고, 비슷한 데이터들을 군집화(clustering)하는 것이다.

Reinforcement Learning 강화학습은 상과 벌이라는 보상(reward)을 주며, 상을 최대화하고 벌을 최소화하도록 학습하는 방식이다.

준 지도학습은, 지도학습과 비지도학습을 결합한 방식이다.

즉, 레이블이 있는 데이터와 레이블이 없는 데이터를 모두 이용하여 학습을 수행한다.

참고: blog.est.ai/2020/11/ssl/

 

Semi-supervised learning 방법론 소개

안녕하세요. 이스트소프트 A.I. PLUS Lab입니다. 이번 포스팅에서는 머신러닝의 학습 방법 중 하나인 준지도학습(semi-supervised learning, SSL)에 대해 다루어보려고 합니다. SSL 자체가 워낙 거대한 주제

blog.est.ai

FixMatch방식

labeled 이미지는 우리가 아는 방식으로, 모델을 학습시키는 데 사용한다.

 

1) Pseudo Labeling방식

unlabeled input을 모델에 넣어서 prediction을 계산한다.

-> maximum Softmax probability가 threshold이상이면 one-hot encoding을 통해 pseudo label을 만든다.

-> 모델의 prediction과 pseudo label간에 Cross-entropy loss를 최소화시키도록 학습을 시킨다.

 

2) Consistency Regularization방식

unlabeled input을 서로 다른 augmentation을 주어서 서로 다른 두 이미지를 만든다. 모델에 input시킴.

-> 그래서 나온 두 prediction이 같아야 한다는 아이디어. l2 loss 또는 cross-entropy loss를 최소화시키도록 학습을 시킨다.

 

3) FixMatch방식 = ( Pseudo Labeling방식 + Consistency Regularization방식 )

unlabeled image에 대해서 하나는 weak augmentation 적용한다. 뒤집거나 미는 방식. 해당 이미지를 모델에 넣어서 prediction을 뽑은 다음, Pseudo Labeling방식을 적용한다. 즉 prediction에 thresholding을 적용한 다음, 넘으면 one-hot encoding을 통해 pseudo label을 만든다.

다른 하나는 strong augmentation 적용한다. 밀고, 대비를 올리는 등 여러 옵션의 augmentation을 생성하고 그 중에서 랜덤으로 선택해 RandAugment(RA) 또는 ControlTheoryAugment(CTA) 방식을 사용한다. CTA는 RA와 같은 방식인데, 각 옵션에 적용되는 magnitude 또한 다이나믹하게 가져오는 차이가 있다. RA 또는 CTA를 적용한 후에 Cut-out Augmentation을 적용한다. 이후 모델에 넣어 prediction을 구한다.

윗줄(weak augmentation)에서 구한 pseudo label과 아랫줄(strong augmentation)에서 구한 prediction간cross-entropy를 최소화하도록 학습한다. 

 

3. Experiments

CIFAR-10과 CIFAR-100, SVHN 데이터셋을 이용해서 실험을 진행하였다.

더보기

각 데이터셋은 이미지 분류를 위한 데이터셋 중 유명한 것들이다.

참고: kjhov195.github.io/2020-02-09-image_dataset_1/

 

Image Dataset(1)

Deep Learning

kjhov195.github.io

결과 error-rate는 다음과 같다.

: Error rates for CIFAR-10, CIFAR-100 and SVHN on 5 different folds

해당 논문에서 제안하는 FixMatch의 방식이 기존 방식들에 비해서 성능이 크게 떨어지지 않고, 오히려 개선되기도 한 것을 확인할 수 있다.

특히 흥미로운 것은 10개의 레이블에 대해서 딱 한 장의 이미지만을 주고 학습을 시켰을 때 CIFAR-10의 정확도가 78%에 이른다는 것이다.

4. Application

스터디원이 FixMatch방식을 직접 적용해본 프로젝트를 소개해주었다.

Naver Fashion Dataset을 가지고, 상품들을 분류하는 작업을 진행하였다.

동물을 분류하는 경우에는 모두 눈이 2개이고, 다리가 대부분 4개인 등 대부분 비슷하여 분류 threshold값이 높아야 하지만, 상품은 품목별로 모두 다르게 생겼으니 쉬울 것이라 예상하였다. ex. 참치캔과 볼펜을 헷갈리지는 않을 것.

하지만, 학습에서 오버피팅이 발생하여 Training Set에 대해서는 100%의 정확도를 가지지만, Validation Set에서는 11%의 정확도에 그쳤다.

이유는, threshold값을 넘어야만 unlabeled 이미지를 사용했기 때문에, 어떠한 레이블에 대해서 고정된 패턴만을 학습했기 때문이라고 하셨다.

이에 대한 해결법은, 모델을 학습할 때, 2 epoch만 labeled 이미지를 이용하여 학습하고, 이후에는 unlabeled 이미지로 학습을 진행하면 나아질 것이라고 하셨다.

 

5. 참고

- 논문스터디

- www.youtube.com/watch?v=fOCxgrR95ew&feature=youtu.be

 

'논문 리뷰' 카테고리의 다른 글

Introduction  (0) 2021.01.01
댓글
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2025/04   »
1 2 3 4 5
6 7 8 9 10 11 12
13 14 15 16 17 18 19
20 21 22 23 24 25 26
27 28 29 30
글 보관함