Sweep이란
Sweep은 하이퍼 파라미터 최적화를 위한 도구이다.
기존 템플릿에 sweep을 추가했다는 건 이 포스팅에서 언급만 하고 넘어갔는데
yaml + OmegaConf + shell + wandb을 활용한 과정을 상세히 정리하고자 한다.
위 도구들을 사용하면 모델 학습과 실험을 간단하고 편리하게 관리할 수 있다.
새롭게 변경된 프로젝트 구조는 다음과 같다.
📁pytorch-template/
│
├── train.py
├── train.sh
├── sweep.sh
├── inference.py
├── requirements.txt
│
├── 📁data/
│
├── 📁dataloader/
│ └── data_loaders.py
│
├── 📁models/
│ ├── model.py
│ ├── optimizer.py
│ └── loss_function.py
│
├── 📁configs/
│ └── base_config.yaml
│ └── sweep_config.py
│
├── 📁trainer
│ └── trainer.py
│
└── 📁notebook/
파일 설명
# train.sh
python3 train.py --config base_config
# sweep.sh
wandb sweep configs/sweep_config.yaml
base_config.yaml
베이스라인에 주어진 인자들을 취향대로 구조화해서 옮겨놓은 yaml 파일이다.
path:
train_path: ../data/train.csv
dev_path: ../data/dev.csv
test_path: ../data/dev.csv
predict_path: ../data/test.csv
data:
shuffle: True
augmentation: # adea, bt 등등
model:
model_name: klue/roberta-small
saved_name: base_model
train:
seed: 42
gpus: 1
batch_size: 16
max_epoch: 1
learning_rate: 1e-5
logging_step: 1
모델을 electra로 바꾼 파일은 electra_config.yaml, 여기에 data 증강을 추가한 파일은 electra_adea_config.yaml 등으로 변형을 줄 수도 있다.
train.py
1. 라이브러리 import
pip install wandb
pip install omegaconf
from omegaconf import OmegaConf
import wandb
from pytorch_lightning.loggers import WandbLogger
2. main에서 인자로 yaml 파일명을 입력받는다
- OmegaConf로 yaml 파일을 로드하게 되면 yaml에 저장된 값들에 .으로 편리하게 접근 가능하다
- wandb logger를 trainer에 넘겨주면 wandb에 학습 결과가 자동으로 기록된다!
(어떤 정보들이 기록되는지는 Model class에서 무엇을 로깅해주는지에 달려있다)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='')
args, _ = parser.parse_known_args()
cfg = OmegaConf.load(f'./config/{args.config}.yaml')
wandb.login()
wandb_logger = WandbLogger(name='first-logger', project='nlp-12')
dataloader = Dataloader(cfg.model.model_name, cfg.train.batch_size, cfg.data.shuffle, cfg.path.train_path, cfg.path.dev_path,
cfg.path.test_path, cfg.path.predict_path)
model = Model(cfg)
trainer = pl.Trainer(gpus=cfg.train.gpus, max_epochs=cfg.train.max_epoch, logger=wandb_logger, log_every_n_steps=cfg.train.logging_step)
trainer.fit(model=model, datamodule=dataloader)
trainer.test(model=model, datamodule=dataloader)
torch.save(model, f'{cfg.model.saved_name}.pt')
Model.py
- 모델 초기화할 때도 lr, model_name 등등 일일이 넘겨받을 필요 없이 이제 config만 넘겨받으면 된다.
- 이렇게 전체 config를 넘겨받게 되면 나중에 wandb에서 해당 학습에 사용된 전체 config를 확인할 수 있어서 어떤 run이 어떤 설정 사용했는지 헷갈릴 때 유용하다.
class Model(pl.LightningModule):
def __init__(self, config):
super().__init__()
self.save_hyperparameters()
self.config = config
self.model_name = config.model.name
self.lr = config.train.learning_rate
self.model = AutoModelForSequenceClassification.from_pretrained(
pretrained_model_name_or_path=self.model_name, num_labels=1
)
self.model.resize_token_embeddings(50137)
self.loss_func = get_loss_func(config)
How to Run (Shell)
하이퍼 파라미터 튜닝 없이 train하려면 아래와 같이 커맨드 창에 입력하면 된다.
$ sh train.sh
sweep을 사용해 하이퍼 파라미터를 튜닝하라면 아래의 명령어를 입력하면 된다
$ sh sweep.sh
# Launch agents
## bayes나 random 탐색은 프로세스를 직접 종료하기 전까지 계속 탐색하므로
## LIMIT_NUM으로 학습 횟수를 제한할 수 있다.
$ wandb agent --count [LIMIT_NUM] [SWEEPID]
만약 permission denied 문제가 발생한다면
chmod +x train.sh
+
train_many.sh
여러 config에 대한 실험을 돌릴 수 있다.
CONFIGS=("base_config", "roberta_config", "electra_config")
for (( i=0; i<3; i++ ))
do
python3 train.py \
—config ${CONFIGS[$i]} \
done
'AI TECH > TIL' 카테고리의 다른 글
Contrastive Learning (0) | 2022.11.03 |
---|---|
STS 대회 에러 해결법 (0) | 2022.11.02 |
[P stage] Week6 Today I Learn (0) | 2022.11.02 |
[실습] Week5 Today I Learn (0) | 2022.10.21 |
[실습] Week4 Today I Learn (0) | 2022.10.19 |