AI TECH/TIL

Contrastive Learning

prefer_all 2022. 11. 3. 13:33

대회를 진행하면서 멘토님께서 Contrastive Learning을 위해서는 Negative Sample을 잘 설계하는 것이 중요하다고 하셨다. Contrastive Learning에 대해서 잘 몰랐던 터라 이를 정리해봤다.

 

<목차>
1. Contrastive Learning
2. Contrastive Learning VS Similarity Learning
3. Distance Metric Learning
- Metric
4. Deep Metric Learning
- 등장 배경
- 설계
- Contrastive Loss, Triplet Loss
5. 대회에서의 적용

 

Contrastive Learning

self-superviesed learning에 사용되는 접근법 중 하나로, input sampel 간의 비교를 통해 학습하는 것이다.

Contrastive Learning은 Metric Learning의 일종이다.

 

딥러닝 모델이 이미지를 잘 이해하고 있다는 것을 우리는 어떻게 알 수 있을까?
비슷한 데이터는 유사도가 높아지고 다른 데이터는 유사도가 낮아지게끔 모델이 학습된다면
모델이 이미지를 잘 이해한다고 판단할 수 있다.

 

Contrastive Learning에서는 positive pair와 negative pair을 사용해서 대상들의 차이를 더 명확하게 학습할 수 있도록 한다.

 

positive pair, negative pair이란

같은 이미지에서 나온 패치를 positive pair라고 하고, 다른 이미지에서 나온 패치는 negative pair라고 한다.
직관적으로도, 원본 이미지에 augmentation을 주면 서로 유사하기 때문에 positive pair라고 이해할 수 있다.

- 우리는 NLP task에서 contrastive learning을 적용하려고 했지만 주로 CV task에서 사용되는 것 같다.

 

Contrastive Learning의 예시를 살펴보자.

아래 그림과 같이 입력으로 (Input(Query), positive, negative)를 받아서 데이터 간 거리를 조절하는 함수 f를 학습할 수 있다.

Query(A)는 우리가 판단하고자 하는 이미지, positive(B)는 그 이미지와 유사한 이미지, negative(C)는 그 이미지와 유사하지 않은 이미지이다.


Contrastive Learning VS Similarity Learning

contrastive learning에 대해 찾아보다보면 similarity learning이 꼭 언급된다.

similarity learning은 supervised learning(지도 학습)에서 사용된다.

similarity learning은 유사도를 반환하는 함수를 학습한다. 


Distance metric Learning

그렇다면 두 데이터 간의 유사도(거리)는 어떻게 측정할 수 있을까?

유사도를 판단하는 기준은 여러 가지가 있다.

 

metric이란

두 객체 간 거리를 정량화하는 함수이다.
Contrastive Learning에서도 어떤 metric을 선택하느냐에 따라 다양한 실험이 가능하다. 
metric의 예시는 아래와 같다.

 

metric learning은 데이터 간 유사도를 수량화하는 'metric(distance) function'을 학습하는 것을 목표로 한다.

학습된 metric function d는 아래와 같다. de는 미리 정의한 embedding 간의 거리 함수이다.

함수 f는 original feature space를 embedding space로 mapping한다. embedding space에서 거리 계산이 쉽기 때문이다. 


Deep metric Learning  등장 배경

[문제] 데이터가 고차원이라면 두 데이터 간 유사도를 비교하는 게 매우 어려워진다. 

즉, 의미적으로 가깝다고 생각되는 고차원 공간에서의 두 데이터의 실제 Euclidean distance는 먼 경우가 많다.

왜냐하면 사람은 비교적 저차원의 공간(의미적인 공간)에서의 거리 개념을 따지는 것과 다르게,

모델은 의미 없는 포인트들만을 따지고, 의미 있는 manifold를 찾지 못하기 때문이다.

 

차원의 저주

차원이 증가하면서 학습 데이터 수가 차원의 수보다 적어져서 성능이 저하되는 현상이다.
차원이 증가할 수록 변수는 증가하는데, 개별 차원 내에서 학습할 데이터 수는 적어진다.
( 주의: 변수가 증가한다고 반드시 차원의 저주가 발생하는 것이 아니라
   관측치보다 변수의 개수가 더 많아지는 경우에 차원의 저주가 발생하는 것이다)

아래 사진처럼 차원이 증가할 수록 빈 공간이 많아진다.
그리고 빈공간이 생겼다는 것은 컴퓨터 상으로 0으로 채워졌다, 정보가 없다는 뜻이다.
정보가 적어지면 당연하게도 모델 성능이 저하된다.

 

 

좌측 사진은 문제 상황이고 우측 사진은 해결책이다.

 

[해결] 고차원 데이터에서 고수준의 유사도를 다루기 위해서는 저차원의 manifold를 찾아야한다.

이를 위해서는 dimension reduction 방식이 필요하고, deep neural network가 바로 이에 해당한다.

즉, deep neural network를 이용해 적정한 manifold를 찾아 metric learning을 연구하는 Deep metric Learning이 활용되는 것이다.

manifold

데이터가 있는 공간을 일컫는다.

데이터는 다양한 차원에서 존재할 수 있는데,
이 차원을 축소시켜 작은 차원에서 보려는 시도들이 있기도 하다. 

Deep metric Learning  설계

de : embedding 간의 거리 함수

x, y : input data

f : embedding 함수일 때, Classification task에서의 목표는 아래와 같다.

 

  • 와 가 동일 클래스일 때 를 작게 만든다
  • 와 가 다른 클래스일 경우, 를 크게 만든다

 

Deep Metric Learning 의 대표적인 loss function은 Contrastive Loss와 Triplet Loss이다.

 

Contrastive Loss

Contrastive Learning에서도 사용되는 loss 함수로, 비슷한 이미지 pair와 비슷하지 않은 이미지 pair에 대한 term을 모두 고려 한다.
아래와 같이 Positive pair에 대해서는 임베딩 거리가 가까워지도록, Negative pair에 대해서는 임베딩 거리가 멀어지도록 한다. 단, margin m이상의 차이가 날 경우 손실함수를 부여하지 않게끔 해서 거리를 최대한 m 이상으로 유지한다.


이렇게 정의한 loss function이 Contrastive loss이고 수식은 아래와 같다. y가 1이면 positive pair로 좌측 식만, y가 0이면 negative pair로 우측 식만 적용된다.


정리하자면 contrastive loss를 통해 학습한다는 것은 두 데이터가 negative pair일때 margin 이상의 거리를 학습하도록 한다.

--------------------------------------------------------------------------------------------------------------------------------------------
Siamese Network(샴 네트워크)에서 Contrastive Loss를 활용한다.

Siamese Network는 두 개의 동일한 하위 네트워크(shared parameters)로 구성된 대칭 신경망 아키텍처이다.

Triplet Loss

Triplet Network는 3개의 동일한 하위 네트워크로 구성되었다.
Anchor, positive, negative 라는 3개의 input을 입력받는다.

a : 입력
p: positive
n: negative
m: margin

마찬가지로 비슷한 데이터 간 거리를 좁히고, 다른 데이터 간 거리를 넓힌다. Contrastive loss와 비슷하게 m을 설정하였는데 positive distance와 negative distance의 거리 차가 m 이상일 경우 loss에 반영하지 않는다.

 

참고   참고2


적용

데이터셋 설계

아래와 같이 기본 baseline이 있다면,

 

 

학습 단계를 두 단계로 나눠서 1차 학습 때 positive, negative pair을 활용해야 한다.

 

 

 

두 가지 방법으로 train 데이터셋을 설계했다

1. binary label이 1인 데이터(유사도가 높은 sentence 1과 sentence2)를 추출한다. 

다음 행 문장(main-pos-neg)으로 negative pair을 생성한다. 

이해를 돕기 위해 내가 만든 예제일 뿐 실제 train 셋이 아님

 

2. negative sample을 잘 설계한다는 것은 헷갈릴 만한 pair을 구성한다는 것이다.

예를 들어 3점을 기준으로 2.8, 2.5점 데이터들은 hard negative 데이터이다. 

binary label이 0인 데이터들 중 2.5~2.9점 데이터들을 negative pair로 이용하고

positive pair은 main sentence를 그대로 복사해보자. (혹은 main sentence에 약간의 변화를 준다. 이를 테면 조사, AEDA 등의 증강 등)

def contrastive_data(data_path):
    """
    binary-label==0(neg pair)인 문장을 기준으로 contrastive learning dataframe을 생성
    main sent의 pos pair는 main sent와 동일한 문장으로 사용
    """
    train = pd.read_csv(data_path)
    neg_pair = train[train["binary-label"] == 0]
    final_df = pd.DataFrame()

    for row in neg_pair.iterrows():

        new_df_row = pd.DataFrame(
            {
                "main_sentence": [row[1]["sentence_1"]],
                "pos_sentence": [row[1]["sentence_1"]],
                "neg_sentence": [row[1]["sentence_2"]],
            }
        )
        final_df = pd.concat([final_df, new_df_row], ignore_index=True)

    return final_df
    
# main에서 cl_train_df = contrastive_data(train_data_path) 이런 식으로 사용

 

학습 두 단계로 나누기

1단계에서 사용할 ElectraModel, 2단계에서 사용할 ContrastiveModel을 위해

Dataset도 ElectraDataset, ContrastiveDataset / DataLoader도 ElectraDataLoader, ContrastiveDataLoader로 나누어 구현했다.

 

## contrastive_train.py의 main

# 주목할 imports
from models.contrastive_model import (
    ContrastiveElectraForSequenceClassification,
    ContrastiveLearnedElectraModel,
    ContrastiveModel,
)
from dataloader.dataloader import ContrastiveDataLoader, ElectraDataLoader
from trainer.trainer import ContrastiveTrainer, Trainer

# ***** 1차 학습 *****
dataloader = ContrastiveDataLoader(
    config.model.name,
    config.train.batch_size,
    config.data.shuffle,
    config.path.cl_train_path,
    config.path.cl_dev_path,
)
model = ContrastiveModel(config) # get Contrastive model
trainer = ContrastiveTrainer(config, wandb_logger) # get Contrastive train
trainer.fit(model=model, datamodule=dataloader) # Contrstive training start
torch.save(model, "contrastive_trained.pt")

# ***** 2차 학습 *****
dataloader = ElectraDataLoader( 
    config.model.name,
    config.train.batch_size,
    config.data.shuffle,
    config.path.train_path,
    config.path.dev_path,
    config.path.test_path,
    config.path.predict_path,
)
model = ContrastiveLearnedElectraModel(config)
trainer = Trainer(config, wandb_logger)
trainer.fit(model=model, datamodule=dataloader)
torch.save(model, "contrastive_trained_2.pt")

 

 

~땡스투 별희 웅니~

 

 

 

 

'AI TECH > TIL' 카테고리의 다른 글

week7,8 면접 준비  (0) 2022.12.14
[TIL] AI와 저작권법  (0) 2022.11.10
STS 대회 에러 해결법  (0) 2022.11.02
Wandb Sweep  (0) 2022.11.02
[P stage] Week6 Today I Learn  (0) 2022.11.02