Once the data is loaded, the next step is to write the training code.
Training code written with pytorch-lightning (pl from here on) is more concise and more object-oriented than plain pytorch,
because the train, validation, and test code is managed together in one place, and each piece of functionality is factored out into its own method.
pl training code can be divided roughly into three parts: the datamodule, the model, and the train script.
First, the Datamodule combines the Dataset and DataLoader classes used in plain pytorch into a single class,
and manages them separately for train, validation, and test.
Next, the model class defines the model architecture, the forward pass, the loss, and the step logic for train, validation, and test.
Finally, the train script simply hands the two classes defined above to the Trainer class and runs the training.
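Schematically, the three pieces fit together like this. This is only a rough sketch with placeholder names (MyDataModule, MyModel); the real classes are shown in full below.

import pytorch_lightning as pl

class MyDataModule(pl.LightningDataModule):
    # wraps Dataset + DataLoader creation for each split
    def setup(self, stage=None): ...
    def train_dataloader(self): ...
    def val_dataloader(self): ...
    def test_dataloader(self): ...

class MyModel(pl.LightningModule):
    # network definition, forward, loss, and per-split step logic
    def forward(self, x): ...
    def training_step(self, batch, batch_idx): ...
    def validation_step(self, batch, batch_idx): ...
    def test_step(self, batch, batch_idx): ...
    def configure_optimizers(self): ...

# the train script then wires the two together:
#   trainer = pl.Trainer(max_epochs=30)
#   trainer.fit(MyModel(), MyDataModule())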
The full code for each part is shown below.
Training/pl_datamodule.py
"""Cifar10 data module."""
import os
import pytorch_lightning as pl
import webdataset as wds
from torch.utils.data import DataLoader
from torchvision import transforms
class CIFAR10DataModule(
pl.LightningDataModule
): # pylint: disable=too-many-instance-attributes
"""Data module class."""
def __init__(self, **kwargs):
"""Initialization of inherited lightning data module."""
super(
CIFAR10DataModule, self
).__init__() # pylint: disable=super-with-arguments
self.train_dataset = None
self.valid_dataset = None
self.test_dataset = None
self.train_data_loader = None
self.val_data_loader = None
self.test_data_loader = None
self.normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
self.valid_transform = transforms.Compose(
[
transforms.ToTensor(),
self.normalize,
]
)
self.train_transform = transforms.Compose(
[
transforms.RandomResizedCrop(32),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
self.normalize,
]
)
self.args = kwargs
def prepare_data(self):
"""Implementation of abstract class."""
@staticmethod
def get_num_files(input_path):
"""Gets num files.
Args:
input_path : path to input
"""
return len(os.listdir(input_path)) - 1
def setup(self, stage=None):
"""Downloads the data, parse it and split the data into train, test,
validation data.
Args:
stage: Stage - training or testing
"""
data_path = self.args.get("dataset_path", "/pvc/output/processing")
train_base_url = data_path + "/train"
val_base_url = data_path + "/val"
test_base_url = data_path + "/test"
train_count = self.get_num_files(train_base_url)
val_count = self.get_num_files(val_base_url)
test_count = self.get_num_files(test_base_url)
train_url = "{}/{}-{}".format(
train_base_url, "train", "{0.." + str(train_count) + "}.tar"
)
valid_url = "{}/{}-{}".format(
val_base_url, "val", "{0.." + str(val_count) + "}.tar"
)
test_url = "{}/{}-{}".format(
test_base_url, "test", "{0.." + str(test_count) + "}.tar"
)
self.train_dataset = (
wds.WebDataset(train_url, handler=wds.warn_and_continue)
.shuffle(100)
.decode("pil")
.rename(image="ppm;jpg;jpeg;png", info="cls")
.map_dict(image=self.train_transform)
.to_tuple("image", "info")
# .batched(40)
)
self.valid_dataset = (
wds.WebDataset(valid_url, handler=wds.warn_and_continue)
.shuffle(100)
.decode("pil")
.rename(image="ppm", info="cls")
.map_dict(image=self.valid_transform)
.to_tuple("image", "info")
# .batched(20)
)
self.test_dataset = (
wds.WebDataset(test_url, handler=wds.warn_and_continue)
.shuffle(100)
.decode("pil")
.rename(image="ppm", info="cls")
.map_dict(image=self.valid_transform)
.to_tuple("image", "info")
# .batched(20)
)
def create_data_loader(
self, dataset, batch_size, num_workers
): # pylint: disable=no-self-use
"""Creates data loader."""
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
def train_dataloader(self):
"""Train Data loader.
Returns:
output - Train data loader for the given input
"""
self.train_data_loader = self.create_data_loader(
self.train_dataset,
self.args.get("train_batch_size", None),
self.args.get("train_num_workers", 4),
)
return self.train_data_loader
def val_dataloader(self):
"""Validation Data Loader.
Returns:
output - Validation data loader for the given input
"""
self.val_data_loader = self.create_data_loader(
self.valid_dataset,
self.args.get("val_batch_size", None),
self.args.get("val_num_workers", 4),
)
return self.val_data_loader
def test_dataloader(self):
"""Test Data Loader.
Returns:
output - Test data loader for the given input
"""
self.test_data_loader = self.create_data_loader(
self.test_dataset,
self.args.get("val_batch_size", None),
self.args.get("val_num_workers", 4),
)
return self.test_data_loader
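As a quick sanity check, the data module can be exercised on its own. This sketch assumes the WebDataset tar shards produced in the previous Data Load step already exist under the given dataset_path; the path and batch sizes here are only illustrative.

from Training.pl_datamodule import CIFAR10DataModule

# Illustrative values; point dataset_path at wherever the shards were written.
dm = CIFAR10DataModule(
    dataset_path="../data",
    train_batch_size=8,
    train_num_workers=0,
    val_batch_size=8,
    val_num_workers=0,
)
dm.setup()
images, labels = next(iter(dm.train_dataloader()))
print(images.shape)  # expected: torch.Size([8, 3, 32, 32])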
Training/pl_model.py
# Pytorch modules
import torch
from torch.nn import functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Pytorch-Lightning
from pytorch_lightning import LightningModule
import torchmetrics
import timm

import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path


class Classifier(LightningModule):
    def __init__(self, **kwargs):
        """method used to define our model parameters"""
        super().__init__()
        self.args = kwargs

        Path("../data/models/plots").mkdir(parents=True, exist_ok=True)

        self.model = timm.create_model(
            self.args.get("model_name", "resnet50"),
            pretrained=True,
            num_classes=self.args.get("num_classes", 10),
        )
        self.width = self.args.get("width", 32)
        self.height = self.args.get("height", 32)

        # metrics
        self.accuracy = torchmetrics.Accuracy()

        # optional - save hyper-parameters to self.hparams
        # they will also be automatically logged as config parameters in W&B
        self.save_hyperparameters()

    def forward(self, x):
        x = self.model(x)
        # x = self.softmax(x)
        return x

    # convenient method to get the loss on a batch
    def loss(self, x, y):
        logits = self(x)  # this calls self.forward
        loss = F.cross_entropy(logits, y)
        return logits, loss

    def training_step(self, batch, batch_idx):
        """needs to return a loss from a single batch"""
        x, y = batch
        if batch_idx == 0:
            self.reference_image = (x[0]).unsqueeze(0)
            # self.reference_image.resize((1,1,28,28))
            # print("\n\nREFERENCE IMAGE!!!")
            # print(self.reference_image.shape)
        logits, loss = self.loss(x, y)
        preds = torch.argmax(logits, 1)

        # Log training loss
        self.log("train_loss", loss)

        # Log metrics
        self.log("train_acc", self.accuracy(preds, y))
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits, loss = self.loss(x, y)
        preds = torch.argmax(logits, 1)

        self.log("val_loss", loss)  # default on val/test is on_epoch only
        self.log("val_acc", self.accuracy(preds, y))
        return logits

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits, loss = self.loss(x, y)
        preds = torch.argmax(logits, 1)

        self.log("test_loss", loss, on_step=False, on_epoch=True)
        self.log("test_acc", self.accuracy(preds, y), on_step=False, on_epoch=True)

    def configure_optimizers(self):
        """defines model optimizer"""
        optimizer = Adam(self.parameters(), lr=self.args.get("lr", 0.0001))
        scheduler = ReduceLROnPlateau(optimizer, min_lr=1e-7)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "train_loss"},
        }

    def makegrid(self, output, numrows):  # pylint: disable=no-self-use
        """Makes grids.

        Args:
            output : Tensor output
            numrows : num of rows.

        Returns:
            c_array : grid array
        """
        outer = torch.Tensor.cpu(output).detach()
        plt.figure(figsize=(20, 5))
        b_array = np.array([]).reshape(0, outer.shape[2])
        c_array = np.array([]).reshape(numrows * outer.shape[2], 0)
        i = 0
        j = 0
        while i < outer.shape[1]:
            img = outer[0][i]
            b_array = np.concatenate((img, b_array), axis=0)
            j += 1
            if j == numrows:
                c_array = np.concatenate((c_array, b_array), axis=1)
                b_array = np.array([]).reshape(0, outer.shape[2])
                j = 0
            i += 1
        return c_array

    def show_activations(self, x_var):
        """Shows activations.

        Args:
            x_var: x variable
        """
        plt.imsave(
            f"../data/models/plots/input_{self.current_epoch}_epoch.png",
            torch.Tensor.cpu(x_var[0][0]),
        )

        # logging layer 1 activations
        out = self.model.conv1(x_var)
        c_grid = self.makegrid(out, 4)
        self.logger.experiment.add_image(
            "layer 1", c_grid, self.current_epoch, dataformats="HW"
        )
        plt.imsave(
            f"../data/models/plots/activation_{self.current_epoch}_epoch.png", c_grid
        )

    def training_epoch_end(self, outputs):
        """Training epoch end.

        Args:
            outputs: outputs of train end
        """
        self.show_activations(self.reference_image)
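The model class can likewise be checked in isolation with a random batch. Note that timm.create_model(..., pretrained=True) downloads ImageNet weights on first use, so this assumes network access; the argument values below simply mirror the defaults used in the train script.

import torch
from Training.pl_model import Classifier

model = Classifier(model_name="resnet50", num_classes=10, lr=1e-3, width=32, height=32)
logits = model(torch.randn(4, 3, 32, 32))  # random CIFAR10-sized batch
print(logits.shape)  # expected: torch.Size([4, 10])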
Training/pl_train.py
"""kubeflow pytorch-lightning training script"""
from pathlib import Path
from argparse import ArgumentParser
from Training.pl_model import Classifier
from Training.pl_datamodule import CIFAR10DataModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import (
EarlyStopping,
LearningRateMonitor,
ModelCheckpoint,
)
# Argument parser for user defined paths
parser = ArgumentParser()
# Train hyperparams
parser.add_argument("--model", default="resnet50", type=str, help="model structure name in timm package")
parser.add_argument("--gpus", default=-1, type=int, help="num of gpus")
parser.add_argument("--max_epochs", default=30, type=int, help="training max epochs")
parser.add_argument("--num_classes", default=10, type=int, help="num_classes")
parser.add_argument("--train_batch_size", type=int, default=256, metavar="N", help="batch size / num_gpus")
parser.add_argument("--train_num_workers", type=int, default=10, metavar="N", help="number of workers")
parser.add_argument("--val_batch_size", type=int, default=256, metavar="N", help="batch size / num_gpus")
parser.add_argument("--val_num_workers", type=int, default=10, metavar="N", help="number of workers")
parser.add_argument("--lr", type=float, default=1e-3, metavar="N", help="learning rate")
# log args
parser.add_argument(
"--check_val_every_n_epoch",
type=int,
default=10,
metavar="N",
help="checkpoint period",
)
parser.add_argument(
"--log_every_n_steps",
default=50,
type=int,
help="log every n steps",
)
# container IO
parser.add_argument(
"--checkpoint_dir",
type=str,
default="/train/models",
help="Path to save model checkpoints (default: output/train/models)",
)
parser.add_argument(
"--dataset_path",
type=str,
default="../data",
help="Cifar10 Dataset path (default: ../data)",
)
if __name__ == "__main__":
# parser = pl.Trainer.add_argparse_args(parent_parser=parser)
args = parser.parse_args()
# Enabling Tensorboard Logger, ModelCheckpoint, Earlystopping
lr_logger = LearningRateMonitor()
early_stopping = EarlyStopping(
monitor="val_loss",
mode="min",
patience=10,
verbose=True,
)
checkpoint_callback = ModelCheckpoint(
dirpath=args.checkpoint_dir,
filename="cifar10_{epoch:02d}_{val_loss:.2f}",
save_top_k=3,
verbose=True,
monitor="val_loss",
mode="min",
)
# Creating parent directories
Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)
# Setting the datamodule specific arguments
datamodule_args = {
"dataset_path": args.dataset_path,
"train_batch_size": args.train_batch_size,
"train_num_workers": args.train_num_workers,
"val_batch_size": args.val_batch_size,
"val_num_workers": args.val_num_workers,
}
datamodule = CIFAR10DataModule(**datamodule_args)
# Initiating the training process
trainer = Trainer(
# logger=wandb_logger, # W&B integration
log_every_n_steps=args.log_every_n_steps, # set the logging frequency
gpus=args.gpus, # use all GPUs
max_epochs=args.max_epochs, # number of epochs
deterministic=True, # keep it deterministic
enable_checkpointing=True,
callbacks=[
# Logger(samples),
early_stopping,
checkpoint_callback,
], # see Callbacks section
precision=16,
check_val_every_n_epoch=args.check_val_every_n_epoch,
strategy="ddp",
auto_scale_batch_size=True,
)
model_config = {
"model_name": args.model,
"num_classes": args.num_classes,
"lr": args.lr,
"width": 32,
"height": 32,
}
# model = trainer.lightning_module
model = Classifier(**model_config)
trainer.fit(model, datamodule)
trainer.test(datamodule=datamodule)
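The script is launched from the repository root so that the Training package is importable (this assumes the project is laid out as a package, e.g. with an __init__.py in Training). A typical invocation, using the defaults above together with the shared PVC path from the earlier pipeline step, might look like:

python -m Training.pl_train \
    --model resnet50 \
    --max_epochs 30 \
    --train_batch_size 256 \
    --val_batch_size 256 \
    --lr 1e-3 \
    --dataset_path /pvc/output/processing \
    --checkpoint_dir /train/models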