대회를 진행하면서 멘토님께서 Contrastive Learning을 위해서는 Negative Sample을 잘 설계하는 것이 중요하다고 하셨다. Contrastive Learning에 대해서 잘 몰랐던 터라 이를 정리해봤다.
<목차> 1. Contrastive Learning 2. Contrastive Learning VS Similarity Learning 3. Distance Metric Learning - Metric 4. Deep Metric Learning - 등장 배경 - 설계 - Contrastive Loss, Triplet Loss 5. 대회에서의 적용
차원이 증가하면서 학습 데이터 수가 차원의 수보다 적어져서 성능이 저하되는 현상이다. 차원이 증가할 수록 변수는 증가하는데, 개별 차원 내에서 학습할 데이터 수는 적어진다. ( 주의: 변수가 증가한다고 반드시 차원의 저주가 발생하는 것이 아니라 관측치보다 변수의 개수가 더 많아지는 경우에 차원의 저주가 발생하는 것이다)
아래 사진처럼 차원이 증가할 수록 빈 공간이 많아진다. 그리고 빈공간이 생겼다는 것은 컴퓨터 상으로 0으로 채워졌다, 정보가 없다는 뜻이다. 정보가 적어지면 당연하게도 모델 성능이 저하된다.
좌측 사진은 문제 상황이고 우측 사진은 해결책이다.
[해결] 고차원 데이터에서 고수준의 유사도를 다루기 위해서는 저차원의 manifold를 찾아야한다.
이를 위해서는 dimension reduction 방식이 필요하고, deep neural network가 바로 이에 해당한다.
즉, deep neural network를 이용해 적정한 manifold를 찾아 metric learning을 연구하는 Deep metric Learning이 활용되는 것이다.
manifold란
데이터가 있는 공간을 일컫는다.
데이터는 다양한 차원에서 존재할 수 있는데, 이 차원을 축소시켜 작은 차원에서 보려는 시도들이 있기도 하다.
Deep metric Learning 설계
de : embedding 간의 거리 함수
x, y : input data
f : embedding 함수일 때, Classification task에서의 목표는 아래와 같다.
x와 y가 동일 클래스일 때 de(f(x),f(y)) 를 작게 만든다
x와 y가 다른 클래스일 경우, de(f(x),f(y)) 를 크게 만든다
Deep Metric Learning 의 대표적인 loss function은 Contrastive Loss와 Triplet Loss이다.
Contrastive Loss란
Contrastive Learning에서도 사용되는 loss 함수로, 비슷한 이미지 pair와 비슷하지 않은 이미지 pair에 대한 term을 모두 고려 한다. 아래와 같이 Positive pair에 대해서는 임베딩 거리가 가까워지도록, Negative pair에 대해서는 임베딩 거리가 멀어지도록 한다. 단, margin m이상의 차이가 날 경우 손실함수를 부여하지 않게끔 해서 거리를 최대한 m 이상으로 유지한다.
이렇게 정의한 loss function이 Contrastive loss이고 수식은 아래와 같다. y가 1이면 positive pair로 좌측 식만, y가 0이면 negative pair로 우측 식만 적용된다.
정리하자면 contrastive loss를 통해 학습한다는 것은 두 데이터가 negative pair일때 margin 이상의 거리를 학습하도록 한다.
-------------------------------------------------------------------------------------------------------------------------------------------- Siamese Network(샴 네트워크)에서 Contrastive Loss를 활용한다. Siamese Network는 두 개의 동일한 하위 네트워크(shared parameters)로 구성된 대칭 신경망 아키텍처이다.
Triplet Loss란
Triplet Network는 3개의 동일한 하위 네트워크로 구성되었다. Anchor, positive, negative 라는 3개의 input을 입력받는다.
a : 입력 p: positive n: negative m: margin
마찬가지로 비슷한 데이터 간 거리를 좁히고, 다른 데이터 간 거리를 넓힌다. Contrastive loss와 비슷하게 m을 설정하였는데 positive distance와 negative distance의 거리 차가 m 이상일 경우 loss에 반영하지 않는다.