MLOps

MLOps E2E - 2-2. CT : Training (kubeflow pipeline)

Hongma 2022. 4. 28. 21:51

 

Once the data has been loaded, the next step is to write the training code.

 

Training code in pytorch-lightning (hereafter pl) is more concise and object-oriented than plain pytorch code,

because the train, validation, and test code are managed in one place and each piece of functionality is factored out into its own method.

 

pl training code can be broadly divided into three parts: the datamodule, the model, and the train script.

 

First, the DataModule combines the Dataset and DataLoader classes used in plain pytorch into a single class,

and it manages the train, validation, and test splits separately.

 

The model class defines the model architecture, the forward pass, the loss, and the behavior of the train, validation, and test steps.

 

Finally, the train script passes the two classes defined above into the Trainer class, which runs the training.
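
To see how the three pieces fit together before the full listings, here is a minimal, self-contained sketch (the ToyDataModule / ToyModel names and the random data are purely illustrative; the real CIFAR10DataModule and Classifier follow below):

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class ToyDataModule(pl.LightningDataModule):
    """Dataset + DataLoader handling in one place."""

    def setup(self, stage=None):
        x, y = torch.randn(256, 10), torch.randint(0, 2, (256,))
        self.train_ds = TensorDataset(x[:200], y[:200])
        self.val_ds = TensorDataset(x[200:], y[200:])

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=32)


class ToyModel(pl.LightningModule):
    """Architecture, forward, loss and the train/validation steps."""

    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(10, 2)

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log("val_loss", torch.nn.functional.cross_entropy(self(x), y))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# the Trainer wires the two classes together
trainer = pl.Trainer(max_epochs=1)
trainer.fit(ToyModel(), datamodule=ToyDataModule())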

 

The full code for each part is as follows.

 

Training/pl_datamodule.py

"""Cifar10 data module."""
import os

import pytorch_lightning as pl
import webdataset as wds
from torch.utils.data import DataLoader
from torchvision import transforms


class CIFAR10DataModule(
    pl.LightningDataModule
):  # pylint: disable=too-many-instance-attributes
    """Data module class."""

    def __init__(self, **kwargs):
        """Initialization of inherited lightning data module."""
        super().__init__()

        self.train_dataset = None
        self.valid_dataset = None
        self.test_dataset = None
        self.train_data_loader = None
        self.val_data_loader = None
        self.test_data_loader = None
        self.normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        self.valid_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                self.normalize,
            ]
        )

        self.train_transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(32),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                self.normalize,
            ]
        )
        self.args = kwargs

    def prepare_data(self):
        """Implementation of abstract class."""

    @staticmethod
    def get_num_files(input_path):
        """Gets num files.
        Args:
             input_path : path to input
        """
        return len(os.listdir(input_path)) - 1

    def setup(self, stage=None):
        """Downloads the data, parse it and split the data into train, test,
        validation data.
        Args:
            stage: Stage - training or testing
        """

        data_path = self.args.get("dataset_path", "/pvc/output/processing")

        train_base_url = data_path + "/train"
        val_base_url = data_path + "/val"
        test_base_url = data_path + "/test"

        train_count = self.get_num_files(train_base_url)
        val_count = self.get_num_files(val_base_url)
        test_count = self.get_num_files(test_base_url)

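        # WebDataset expands the brace pattern below into a list of shards,
        # e.g. "train-{0..9}.tar" -> train-0.tar, train-1.tar, ..., train-9.tar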
        train_url = "{}/{}-{}".format(
            train_base_url, "train", "{0.." + str(train_count) + "}.tar"
        )
        valid_url = "{}/{}-{}".format(
            val_base_url, "val", "{0.." + str(val_count) + "}.tar"
        )
        test_url = "{}/{}-{}".format(
            test_base_url, "test", "{0.." + str(test_count) + "}.tar"
        )

        self.train_dataset = (
            wds.WebDataset(train_url, handler=wds.warn_and_continue)
            .shuffle(100)
            .decode("pil")
            .rename(image="ppm;jpg;jpeg;png", info="cls")
            .map_dict(image=self.train_transform)
            .to_tuple("image", "info")
            # .batched(40)
        )

        self.valid_dataset = (
            wds.WebDataset(valid_url, handler=wds.warn_and_continue)
            .shuffle(100)
            .decode("pil")
            .rename(image="ppm", info="cls")
            .map_dict(image=self.valid_transform)
            .to_tuple("image", "info")
            # .batched(20)
        )

        self.test_dataset = (
            wds.WebDataset(test_url, handler=wds.warn_and_continue)
            .shuffle(100)
            .decode("pil")
            .rename(image="ppm", info="cls")
            .map_dict(image=self.valid_transform)
            .to_tuple("image", "info")
            # .batched(20)
        )

    def create_data_loader(
        self, dataset, batch_size, num_workers
    ):  # pylint: disable=no-self-use
        """Creates data loader."""
        return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)

    def train_dataloader(self):
        """Train Data loader.
        Returns:
             output - Train data loader for the given input
        """
        self.train_data_loader = self.create_data_loader(
            self.train_dataset,
            self.args.get("train_batch_size", None),
            self.args.get("train_num_workers", 4),
        )
        return self.train_data_loader

    def val_dataloader(self):
        """Validation Data Loader.
        Returns:
             output - Validation data loader for the given input
        """
        self.val_data_loader = self.create_data_loader(
            self.valid_dataset,
            self.args.get("val_batch_size", None),
            self.args.get("val_num_workers", 4),
        )
        return self.val_data_loader

    def test_dataloader(self):
        """Test Data Loader.
        Returns:
             output - Test data loader for the given input
        """
        self.test_data_loader = self.create_data_loader(
            self.test_dataset,
            self.args.get("val_batch_size", None),
            self.args.get("val_num_workers", 4),
        )
        return self.test_data_loader
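
If you want a quick local sanity check of the data module (not part of the pipeline itself), something like the following should work, assuming the shards produced by the previous data-load step exist under /pvc/output/processing:

from Training.pl_datamodule import CIFAR10DataModule

dm = CIFAR10DataModule(
    dataset_path="/pvc/output/processing",  # assumption: output path of the data-load step
    train_batch_size=32,
    train_num_workers=0,
)
dm.setup()
images, labels = next(iter(dm.train_dataloader()))
print(images.shape)  # expected: torch.Size([32, 3, 32, 32])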

 

Training/pl_model.py

# Pytorch modules
import torch
from torch.nn import functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Pytorch-Lightning
from pytorch_lightning import LightningModule

import torchmetrics
import timm
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path


class Classifier(LightningModule):
    def __init__(self, **kwargs):
        """method used to define our model parameters"""
        super().__init__()

        self.args = kwargs
        Path("../data/models/plots").mkdir(parents=True, exist_ok=True)

        self.model = timm.create_model(
            self.args.get("model_name", "resnet50"),
            pretrained=True,
            num_classes=self.args.get("num_classes", 10),
        )

        self.width = self.args.get("width", 32)
        self.height = self.args.get("height", 32)

        # metrics
        self.accuracy = torchmetrics.Accuracy()

        # optional - save hyper-parameters to self.hparams
        # they will also be automatically logged as config parameters in W&B
        self.save_hyperparameters()

    def forward(self, x):
        x = self.model(x)
        # x = self.softmax(x)
        return x

    # convenient method to get the loss on a batch
    def loss(self, x, y):
        logits = self(x)  # this calls self.forward
        loss = F.cross_entropy(logits, y)
        return logits, loss

    def training_step(self, batch, batch_idx):
        """needs to return a loss from a single batch"""
        x, y = batch
        if batch_idx == 0:
            # keep one sample from the first batch for the activation plots at epoch end
            self.reference_image = x[0].unsqueeze(0)
        logits, loss = self.loss(x, y)
        preds = torch.argmax(logits, 1)

        # Log training loss
        self.log("train_loss", loss)

        # Log metrics
        self.log("train_acc", self.accuracy(preds, y))

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits, loss = self.loss(x, y)
        preds = torch.argmax(logits, 1)

        self.log("val_loss", loss)  # default on val/test is on_epoch only
        self.log("val_acc", self.accuracy(preds, y))

        return logits

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits, loss = self.loss(x, y)
        preds = torch.argmax(logits, 1)

        self.log("test_loss", loss, on_step=False, on_epoch=True)
        self.log("test_acc", self.accuracy(preds, y), on_step=False, on_epoch=True)

    def configure_optimizers(self):
        """defines model optimizer"""
        optimizer = Adam(self.parameters(), lr=self.args.get("lr", 0.0001))
        scheduler = ReduceLROnPlateau(optimizer, min_lr=1e-7)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "train_loss"},
        }

    def makegrid(self, output, numrows):  # pylint: disable=no-self-use
        """Makes grids.

        Args:
             output : Tensor output
             numrows : num of rows.
        Returns:
             c_array : grid array
        """
        outer = torch.Tensor.cpu(output).detach()
        plt.figure(figsize=(20, 5))
        b_array = np.array([]).reshape(0, outer.shape[2])
        c_array = np.array([]).reshape(numrows * outer.shape[2], 0)
        i = 0
        j = 0
        while i < outer.shape[1]:
            img = outer[0][i]
            b_array = np.concatenate((img, b_array), axis=0)
            j += 1
            if j == numrows:
                c_array = np.concatenate((c_array, b_array), axis=1)
                b_array = np.array([]).reshape(0, outer.shape[2])
                j = 0

            i += 1
        return c_array

    def show_activations(self, x_var):
        """Showns activation
        Args:
             x_var: x variable
        """
        plt.imsave(
            f"../data/models/plots/input_{self.current_epoch}_epoch.png",
            torch.Tensor.cpu(x_var[0][0]),
        )

        # logging layer 1 activations
        out = self.model.conv1(x_var)
        c_grid = self.makegrid(out, 4)
        self.logger.experiment.add_image(
            "layer 1", c_grid, self.current_epoch, dataformats="HW"
        )

        plt.imsave(
            f"../data/models/plots/activation_{self.current_epoch}_epoch.png", c_grid
        )

    def training_epoch_end(self, outputs):
        """Training epoch end.

        Args:
             outputs: outputs of train end
        """
        self.show_activations(self.reference_image)

 

Training/pl_train.py

"""kubeflow pytorch-lightning training script"""
from pathlib import Path
from argparse import ArgumentParser
from Training.pl_model import Classifier
from Training.pl_datamodule import CIFAR10DataModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)

# Argument parser for user defined paths
parser = ArgumentParser()

# Train hyperparams
parser.add_argument("--model", default="resnet50", type=str, help="model structure name in timm package")
parser.add_argument("--gpus", default=-1, type=int, help="num of gpus")
parser.add_argument("--max_epochs", default=30, type=int, help="training max epochs")
parser.add_argument("--num_classes", default=10, type=int, help="num_classes")
parser.add_argument("--train_batch_size", type=int, default=256, metavar="N", help="batch size / num_gpus")
parser.add_argument("--train_num_workers", type=int, default=10, metavar="N", help="number of workers")
parser.add_argument("--val_batch_size", type=int, default=256, metavar="N", help="batch size / num_gpus")
parser.add_argument("--val_num_workers", type=int, default=10, metavar="N", help="number of workers")
parser.add_argument("--lr", type=float, default=1e-3, metavar="N", help="learning rate")

# log args
parser.add_argument(
    "--check_val_every_n_epoch",
    type=int,
    default=10,
    metavar="N",
    help="checkpoint period",
)

parser.add_argument(
    "--log_every_n_steps",
    default=50,
    type=int,
    help="log every n steps",
)

# container IO
parser.add_argument(
    "--checkpoint_dir",
    type=str,
    default="/train/models",
    help="Path to save model checkpoints (default: output/train/models)",
)

parser.add_argument(
    "--dataset_path",
    type=str,
    default="../data",
    help="Cifar10 Dataset path (default: ../data)",
)



if __name__ == "__main__":
    # parser = pl.Trainer.add_argparse_args(parent_parser=parser)
    args = parser.parse_args()

    # Enabling Tensorboard Logger, ModelCheckpoint, Earlystopping

    lr_logger = LearningRateMonitor()
    early_stopping = EarlyStopping(
        monitor="val_loss",
        mode="min",
        patience=10,
        verbose=True,
    )
    checkpoint_callback = ModelCheckpoint(
        dirpath=args.checkpoint_dir,
        filename="cifar10_{epoch:02d}_{val_loss:.2f}",
        save_top_k=3,
        verbose=True,
        monitor="val_loss",
        mode="min",
    )

    # Creating parent directories
    Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)

    # Setting the datamodule specific arguments
    datamodule_args = {
        "dataset_path": args.dataset_path,
        "train_batch_size": args.train_batch_size,
        "train_num_workers": args.train_num_workers,
        "val_batch_size": args.val_batch_size,
        "val_num_workers": args.val_num_workers,
    }

    datamodule = CIFAR10DataModule(**datamodule_args)

    # Initiating the training process
    trainer = Trainer(
        # logger=wandb_logger,    # W&B integration
        log_every_n_steps=args.log_every_n_steps,  # set the logging frequency
        gpus=args.gpus,  # use all GPUs
        max_epochs=args.max_epochs,  # number of epochs
        deterministic=True,  # keep it deterministic
        enable_checkpointing=True,
        callbacks=[
            # Logger(samples),
            early_stopping,
            checkpoint_callback,
        ],  # see Callbacks section
        precision=16,
        check_val_every_n_epoch=args.check_val_every_n_epoch,
        strategy="ddp",
        auto_scale_batch_size=True,
    )

    model_config = {
        "model_name": args.model,
        "num_classes": args.num_classes,
        "lr": args.lr,
        "width": 32,
        "height": 32,
    }
    # model = trainer.lightning_module
    model = Classifier(**model_config)

    trainer.fit(model, datamodule)
    trainer.test(datamodule=datamodule)
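
After training, the checkpoints written by ModelCheckpoint can be restored later for evaluation or inference. A minimal sketch, assuming a checkpoint exists at the hypothetical path below (ModelCheckpoint saves the top-3 files under --checkpoint_dir using the "cifar10_{epoch:02d}_{val_loss:.2f}" pattern):

from Training.pl_model import Classifier

ckpt_path = "/train/models/cifar10_epoch=29_val_loss=0.45.ckpt"  # hypothetical file name
model = Classifier.load_from_checkpoint(ckpt_path)  # hparams were stored via save_hyperparameters()
model.eval()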

 

 

 
