AI TECH/TIL

Wandb Sweep

prefer_all 2022. 11. 2. 14:29

Sweep이란

 

Sweep은 하이퍼 파라미터 최적화를 위한 도구이다.

 

sweep을 사용하면 각 hyperparameter의 중요도나 관계성을 알아보기도 쉽다

 

기존 템플릿에 sweep을 추가했다는 건 이 포스팅에서 언급만 하고 넘어갔는데 

yaml + OmegaConf + shell + wandb을 활용한 과정을 상세히 정리하고자 한다.

위 도구들을 사용하면 모델 학습과 실험을 간단하고 편리하게 관리할 수 있다.

새롭게 변경된 프로젝트 구조는 다음과 같다.

 

📁pytorch-template/

├── train.py

├── train.sh
├── sweep.sh
├── inference.py
├── requirements.txt 

├── 📁data/

├── 📁dataloader/ 
│ └── data_loaders.py 

├── 📁models/
│ ├── model.py 
│ ├── optimizer.py
│ └── loss_function.py 

├── 📁configs/ 
│ └── base_config.yaml
│ └── sweep_config.py 

├── 📁trainer
│ └── trainer.py 

└── 📁notebook/

파일 설명

# train.sh
python3 train.py --config base_config
# sweep.sh
wandb sweep configs/sweep_config.yaml

 

base_config.yaml

베이스라인에 주어진 인자들을 취향대로 구조화해서 옮겨놓은 yaml 파일이다.

path:
    train_path: ../data/train.csv
    dev_path: ../data/dev.csv
    test_path: ../data/dev.csv
    predict_path: ../data/test.csv

data:
    shuffle: True
    augmentation: # adea, bt 등등
    
model:
    model_name: klue/roberta-small
    saved_name: base_model

train:
    seed: 42
    gpus: 1
    batch_size: 16
    max_epoch: 1
    learning_rate: 1e-5
    logging_step: 1

모델을 electra로 바꾼 파일은 electra_config.yaml, 여기에 data 증강을 추가한 파일은 electra_adea_config.yaml 등으로 변형을 줄 수도 있다.

 

 

train.py

1. 라이브러리 import

pip install wandb
pip install omegaconf

from omegaconf import OmegaConf
import wandb
from pytorch_lightning.loggers import WandbLogger

2. main에서 인자로 yaml 파일명을 입력받는다

OmegaConf로 yaml 파일을 로드하게 되면 yaml에 저장된 값들에 .으로 편리하게 접근 가능하다

-  wandb logger를 trainer에 넘겨주면 wandb에 학습 결과가 자동으로 기록된다!

    (어떤 정보들이 기록되는지는 Model class에서 무엇을 로깅해주는지에 달려있다)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='')
    args, _ = parser.parse_known_args()

    cfg = OmegaConf.load(f'./config/{args.config}.yaml')
        
    wandb.login()
    wandb_logger = WandbLogger(name='first-logger', project='nlp-12')

    dataloader = Dataloader(cfg.model.model_name, cfg.train.batch_size, cfg.data.shuffle, cfg.path.train_path, cfg.path.dev_path,
                            cfg.path.test_path, cfg.path.predict_path)
    model = Model(cfg)

    trainer = pl.Trainer(gpus=cfg.train.gpus, max_epochs=cfg.train.max_epoch, logger=wandb_logger, log_every_n_steps=cfg.train.logging_step)

    trainer.fit(model=model, datamodule=dataloader)
    trainer.test(model=model, datamodule=dataloader)

    torch.save(model, f'{cfg.model.saved_name}.pt')

 

 

Model.py

-  모델 초기화할 때도 lr, model_name 등등 일일이 넘겨받을 필요 없이 이제 config만 넘겨받으면 된다.

-  이렇게 전체 config를 넘겨받게 되면 나중에 wandb에서 해당 학습에 사용된 전체 config를 확인할 수 있어서 어떤 run이 어떤 설정 사용했는지 헷갈릴 때 유용하다.

class Model(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()

        self.config = config
        self.model_name = config.model.name
        self.lr = config.train.learning_rate

        self.model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=self.model_name, num_labels=1
        )

        self.model.resize_token_embeddings(50137)
        self.loss_func = get_loss_func(config)


How to Run (Shell)

하이퍼 파라미터 튜닝 없이 train하려면 아래와 같이 커맨드 창에 입력하면 된다.

$ sh train.sh

 

sweep을 사용해 하이퍼 파라미터를 튜닝하라면 아래의 명령어를 입력하면 된다

$ sh sweep.sh

# Launch agents
## bayes나 random 탐색은 프로세스를 직접 종료하기 전까지 계속 탐색하므로 
## LIMIT_NUM으로 학습 횟수를 제한할 수 있다.
$ wandb agent --count [LIMIT_NUM] [SWEEPID]

 

만약 permission denied 문제가 발생한다면 

chmod +x train.sh

 

 

 

+

train_many.sh

여러 config에 대한 실험을 돌릴 수 있다.

CONFIGS=("base_config", "roberta_config", "electra_config")

for (( i=0; i<3; i++ ))
do
    python3 train.py \
        —config ${CONFIGS[$i]} \
done

 

 

출처

 

'AI TECH > TIL' 카테고리의 다른 글

Contrastive Learning  (0) 2022.11.03
STS 대회 에러 해결법  (0) 2022.11.02
[P stage] Week6 Today I Learn  (0) 2022.11.02
[실습] Week5 Today I Learn  (0) 2022.10.21
[실습] Week4 Today I Learn  (0) 2022.10.19