티스토리 뷰
[논문 리뷰] FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence
donie 2020. 12. 6. 23:46FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence
google-research/fixmatch
A simple method to perform semi-supervised learning with limited data. - google-research/fixmatch
github.com
1. Introduction
딥러닝 모델이 괜찮은 결과를 내기 위해서는 라벨 데이터의 개수가 많아야 한다.
ex. GPT-3의 경우, 175B (=1750억)개의 파라미터를 사용한다.
※ GPT-3는 무엇인가?
GPT-3는 OpenAI가 만든, 딥러닝을 이용해 인간다운 텍스트를 만들어내는 자기회귀 언어 모델이다.
GPT-3 - 나무위키
OpenAI사가 개발한 인공 일반 지능 모델. Generation Pre-trained Transformer 3(GPT-3)은 딥러닝을 이용해 인간다운 텍스트를 만들어내는 자기회귀 언어 모델이다. OpenAI사가 만든 GPT-n 시리즈의 3세대 언어 예
namu.wiki
2. Semi-Supervised Learning
그래서 사용하는 것이 Semi-Supervised Learning 기법이다.
우선 Supervised Learning 지도학습은 정답이 있는 데이터로 학습을 하여 어떤 데이터를 분류하거나 (분류, Classification) / 값을 예측하는 것 (회귀, Regression) 이다.
Unsupervised Learning 비지도학습은 정답을 따로 알려주지 않고, 비슷한 데이터들을 군집화(clustering)하는 것이다.
Reinforcement Learning 강화학습은 상과 벌이라는 보상(reward)을 주며, 상을 최대화하고 벌을 최소화하도록 학습하는 방식이다.
준 지도학습은, 지도학습과 비지도학습을 결합한 방식이다.
즉, 레이블이 있는 데이터와 레이블이 없는 데이터를 모두 이용하여 학습을 수행한다.
Semi-supervised learning 방법론 소개
안녕하세요. 이스트소프트 A.I. PLUS Lab입니다. 이번 포스팅에서는 머신러닝의 학습 방법 중 하나인 준지도학습(semi-supervised learning, SSL)에 대해 다루어보려고 합니다. SSL 자체가 워낙 거대한 주제
blog.est.ai
FixMatch방식
labeled 이미지는 우리가 아는 방식으로, 모델을 학습시키는 데 사용한다.
1) Pseudo Labeling방식
unlabeled input을 모델에 넣어서 prediction을 계산한다.
-> maximum Softmax probability가 threshold이상이면 one-hot encoding을 통해 pseudo label을 만든다.
-> 모델의 prediction과 pseudo label간에 Cross-entropy loss를 최소화시키도록 학습을 시킨다.
2) Consistency Regularization방식
unlabeled input을 서로 다른 augmentation을 주어서 서로 다른 두 이미지를 만든다. 모델에 input시킴.
-> 그래서 나온 두 prediction이 같아야 한다는 아이디어. l2 loss 또는 cross-entropy loss를 최소화시키도록 학습을 시킨다.
3) FixMatch방식 = ( Pseudo Labeling방식 + Consistency Regularization방식 )
unlabeled image에 대해서 하나는 weak augmentation 적용한다. 뒤집거나 미는 방식. 해당 이미지를 모델에 넣어서 prediction을 뽑은 다음, Pseudo Labeling방식을 적용한다. 즉 prediction에 thresholding을 적용한 다음, 넘으면 one-hot encoding을 통해 pseudo label을 만든다.
다른 하나는 strong augmentation 적용한다. 밀고, 대비를 올리는 등 여러 옵션의 augmentation을 생성하고 그 중에서 랜덤으로 선택해 RandAugment(RA) 또는 ControlTheoryAugment(CTA) 방식을 사용한다. CTA는 RA와 같은 방식인데, 각 옵션에 적용되는 magnitude 또한 다이나믹하게 가져오는 차이가 있다. RA 또는 CTA를 적용한 후에 Cut-out Augmentation을 적용한다. 이후 모델에 넣어 prediction을 구한다.
윗줄(weak augmentation)에서 구한 pseudo label과 아랫줄(strong augmentation)에서 구한 prediction간의 cross-entropy를 최소화하도록 학습한다.
3. Experiments
CIFAR-10과 CIFAR-100, SVHN 데이터셋을 이용해서 실험을 진행하였다.
각 데이터셋은 이미지 분류를 위한 데이터셋 중 유명한 것들이다.
참고: kjhov195.github.io/2020-02-09-image_dataset_1/
Image Dataset(1)
Deep Learning
kjhov195.github.io
결과 error-rate는 다음과 같다.
해당 논문에서 제안하는 FixMatch의 방식이 기존 방식들에 비해서 성능이 크게 떨어지지 않고, 오히려 개선되기도 한 것을 확인할 수 있다.
특히 흥미로운 것은 10개의 레이블에 대해서 딱 한 장의 이미지만을 주고 학습을 시켰을 때 CIFAR-10의 정확도가 78%에 이른다는 것이다.
4. Application
스터디원이 FixMatch방식을 직접 적용해본 프로젝트를 소개해주었다.
Naver Fashion Dataset을 가지고, 상품들을 분류하는 작업을 진행하였다.
동물을 분류하는 경우에는 모두 눈이 2개이고, 다리가 대부분 4개인 등 대부분 비슷하여 분류 threshold값이 높아야 하지만, 상품은 품목별로 모두 다르게 생겼으니 쉬울 것이라 예상하였다. ex. 참치캔과 볼펜을 헷갈리지는 않을 것.
하지만, 학습에서 오버피팅이 발생하여 Training Set에 대해서는 100%의 정확도를 가지지만, Validation Set에서는 11%의 정확도에 그쳤다.
이유는, threshold값을 넘어야만 unlabeled 이미지를 사용했기 때문에, 어떠한 레이블에 대해서 고정된 패턴만을 학습했기 때문이라고 하셨다.
이에 대한 해결법은, 모델을 학습할 때, 2 epoch만 labeled 이미지를 이용하여 학습하고, 이후에는 unlabeled 이미지로 학습을 진행하면 나아질 것이라고 하셨다.
5. 참고
- 논문스터디
- www.youtube.com/watch?v=fOCxgrR95ew&feature=youtu.be
'논문 리뷰' 카테고리의 다른 글
Introduction (0) | 2021.01.01 |
---|
- Total
- Today
- Yesterday
- 우분투
- 백준알고리즘
- Ubuntu20.04
- python3
- 초음파센서
- Python
- roslaunch
- vue/cli
- Publisher
- VMware
- 포트인식문제
- 8자주행
- sensehat
- 아두이노 IDE
- 리눅스
- Ubuntu16.04
- subscriber
- 코드리뷰
- ROS
- HC-SR04
- VirtualBox
- filesystem
- Mount
- 윈도우 복구
- 윈도우
- umount
- 프로그래머스
- 원격 통신
- C++
- set backspace
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | ||
6 | 7 | 8 | 9 | 10 | 11 | 12 |
13 | 14 | 15 | 16 | 17 | 18 | 19 |
20 | 21 | 22 | 23 | 24 | 25 | 26 |
27 | 28 | 29 | 30 |