

PyTorch (1) Dataset, DataLoader


by 릿카。 2024. 7. 29. 10:47


* The author is not a professional coder, so some explanations may be inaccurate. The goal is to cover the essentials within the big picture of training a deep learning model.

 

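When implementing deep learning with PyTorch, the first classes you are likely to meet are Dataset and DataLoader. Deep learning relies on large datasets, and loading all of that data into RAM at once is inefficient, so PyTorch provides these two classes to load samples only as they are needed. Rather than holding the data itself, they define the 'way' the data is loaded.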

 

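Dataset: returns data by index. Given the paths or locations of x and y, it returns x[i], y[i] for an index i, much like indexing a list. (PyTorch calls this a map-style dataset; stream-like data uses the separate IterableDataset class.)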

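DataLoader: wraps a Dataset and adds batching, parallel loading with multiple processes, shuffling, samplers and so on, so that the samples {x[i], y[i]} can be fed to training in whatever way you need.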
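A minimal sketch of the two roles, using a hypothetical ToyDataset (not from the original post) whose x[i] is the number i and y[i] is i % 2: the Dataset answers "give me sample i", and the DataLoader groups those samples into batches.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Hypothetical toy dataset: x[i] = [i] as a float tensor, y[i] = i % 2."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        # The Dataset's only job: return the single sample (x[i], y[i])
        return torch.tensor([float(i)]), torch.tensor(i % 2)

ds = ToyDataset(10)
x3, y3 = ds[3]  # indexing calls __getitem__(3)

# The DataLoader's job: group samples into batches (here 4, 4 and 2 samples)
loader = DataLoader(ds, batch_size=4, shuffle=False)
for xb, yb in loader:
    print(xb.shape, yb.tolist())
# torch.Size([4, 1]) [0, 1, 0, 1]
# torch.Size([4, 1]) [0, 1, 0, 1]
# torch.Size([2, 1]) [0, 1]
```

The default collate function stacks the per-sample tensors, which is why each x batch has shape [batch_size, 1].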

 

The following code is taken from the PyTorch GitHub repository (https://github.com/pytorch/pytorch/blob/main/torch/utils/data/dataset.py):

class Dataset(Generic[_T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`. Subclasses could also
    optionally implement :meth:`__getitems__`, for speedup batched samples
    loading. This method accepts list of indices of samples of batch and returns
    list of samples.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> _T_co:
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    # def __getitems__(self, indices: List) -> List[_T_co]:
    # Not implemented to prevent false-positives in fetcher check in
    # torch.utils.data._utils.fetch._MapDatasetFetcher

    def __add__(self, other: "Dataset[_T_co]") -> "ConcatDataset[_T_co]":
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py
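The quoted source also defines __add__ (so ds1 + ds2 builds a ConcatDataset) and mentions an optional __getitems__ for batched loading. A small sketch with a hypothetical Range dataset; it assumes a PyTorch 2.x version whose DataLoader fetcher checks for __getitems__, as the main-branch source above does.

```python
from torch.utils.data import Dataset, DataLoader, ConcatDataset

class Range(Dataset):
    """Hypothetical dataset returning start, start+1, ..., stop-1."""

    def __init__(self, start, stop):
        self.start, self.stop = start, stop
        self.batched_calls = 0  # counts __getitems__ calls, for demonstration

    def __len__(self):
        return self.stop - self.start

    def __getitem__(self, i):
        return self.start + i

    def __getitems__(self, indices):
        # Optional batched fetch: the DataLoader passes all indices of a batch at once
        self.batched_calls += 1
        return [self.start + i for i in indices]

a, b = Range(0, 3), Range(10, 12)
both = a + b  # Dataset.__add__ returns ConcatDataset([a, b])
print(isinstance(both, ConcatDataset), len(both), [both[i] for i in range(len(both))])
# True 5 [0, 1, 2, 10, 11]

batches = [batch.tolist() for batch in DataLoader(a, batch_size=2)]
print(batches, a.batched_calls)  # [[0, 1], [2]] 2
```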

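You have to define __getitem__ and, in practice, __len__ yourself. __getitem__ is the function that actually loads one sample. According to the docstring __len__ is optional, but many Sampler implementations and the DataLoader defaults expect it (both the default sequential sampler and the random sampler used by shuffle=True need the dataset size), so a map-style dataset used with DataLoader should define it.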
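To see why __len__ matters in practice, here is a small sketch with hypothetical classes SquaresNoLen and Squares: indexing works without __len__, but iterating a DataLoader with the default sampler fails, because the sampler calls len(dataset).

```python
from torch.utils.data import Dataset, DataLoader

class SquaresNoLen(Dataset):
    """Hypothetical dataset: sample i is i * i. Only __getitem__ is defined."""

    def __getitem__(self, i):
        return i * i

class Squares(SquaresNoLen):
    """Same data, plus __len__ so the default sampler knows the dataset size."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

no_len = SquaresNoLen()
print(no_len[3])  # 9: plain indexing works without __len__

try:
    next(iter(DataLoader(no_len, batch_size=2)))
except TypeError as e:
    # The default SequentialSampler calls len(dataset), which fails here
    print("DataLoader failed:", e)

for batch in DataLoader(Squares(5), batch_size=2):
    print(batch)  # Python ints are collated into one integer tensor per batch
```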

 

The official PyTorch tutorial includes an example (https://pytorch.org/tutorials/beginner/basics/data_tutorial.html):

import os
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # CSV with one row per sample: column 0 = image file name, column 1 = label
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # Number of samples = number of rows in the annotations CSV
        return len(self.img_labels)

    def __getitem__(self, idx):
        # The image for this index is read from disk only now, one sample at a time
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

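As mentioned above, __getitem__ is the function that returns x[i], y[i]. Note that __init__ only reads the annotations CSV; each image is read from disk only when __getitem__ is called for that index, which is what keeps memory usage low.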
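To make the memory argument concrete without needing image files, here is a minimal sketch in the same spirit as CustomImageDataset. The class TensorFileDataset, the .pt files written to a temporary directory, and the load counter are all hypothetical: __init__ stores only file paths, and __getitem__ reads one sample per call and applies an optional transform.

```python
import os
import tempfile
import torch
from torch.utils.data import Dataset

class TensorFileDataset(Dataset):
    """Hypothetical dataset: each sample is stored as its own .pt file.
    Only the list of paths is kept in memory; tensors are loaded in __getitem__."""

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform
        self.load_count = 0  # counts disk reads, for demonstration only

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        x = torch.load(self.file_paths[idx], weights_only=True)  # read one sample
        self.load_count += 1
        if self.transform:
            x = self.transform(x)
        return x, self.labels[idx]

# Write 5 small tensors to disk: sample i is [i, i, i]
tmpdir = tempfile.mkdtemp()
paths = []
for i in range(5):
    path = os.path.join(tmpdir, f"sample_{i}.pt")
    torch.save(torch.full((3,), float(i)), path)
    paths.append(path)

ds = TensorFileDataset(paths, labels=[0, 1, 0, 1, 0], transform=lambda t: t * 2)
print(ds.load_count)  # 0: nothing has been read from disk yet
x, y = ds[2]
print(x, y, ds.load_count)  # tensor([4., 4., 4.]) 0 1
```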

 

Once you have defined a Dataset class that returns x[i], y[i] for a given i, you need a DataLoader that groups these samples {x[i], y[i]} into mini-batches so they can be used for training. The DataLoader is used in a context like this:

# Training loop
# (model, device, train_loader and test_loader are assumed to be defined elsewhere)
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, (images, labels) in enumerate(train_loader):
        # Each iteration yields one mini-batch from the DataLoader
        images, labels = images.to(device), labels.to(device)
        # Forward pass
        outputs = model(images)
        # ~

# Testing the model
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # ~
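The loop above leaves out the loss and update steps. As a self-contained sketch of one common way such a loop is completed, here it is on hypothetical synthetic data: 2-D points labeled by whether their coordinates sum to a positive number, a single nn.Linear layer, CrossEntropyLoss and SGD with lr=0.1. The data, model, loss and optimizer are all assumptions for illustration, not the author's setup.

```python
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)

# Hypothetical synthetic data: label 1 if x0 + x1 > 0, else 0
X = torch.randn(256, 2)
y = (X.sum(dim=1) > 0).long()
train_loader = DataLoader(TensorDataset(X[:200], y[:200]), batch_size=32, shuffle=True)
test_loader = DataLoader(TensorDataset(X[200:], y[200:]), batch_size=32, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(2, 2).to(device)  # 2 input features -> 2 class logits
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)            # forward pass
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()                    # backward pass
        optimizer.step()                   # parameter update
        running_loss += loss.item()
    print(f"epoch {epoch + 1}: mean loss {running_loss / len(train_loader):.4f}")

model.eval()
correct = total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        predicted = model(inputs).argmax(dim=1)  # class with the largest logit
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
accuracy = correct / total
print(f"test accuracy: {accuracy:.2%}")
```

len(train_loader) is the number of batches per epoch (7 here: six full batches of 32 and one of 8), which is why the running loss is divided by it.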

 

For the DataLoader class itself, see the source at https://github.com/pytorch/pytorch/blob/main/torch/utils/data/dataloader.py. Typical usage looks like this:

from torch.utils.data import DataLoader

# training_data / test_data: Dataset objects (FashionMNIST in the official tutorial)
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

and the constructor arguments are documented in the source as follows:

"""
Args:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
        sampler (Sampler or Iterable, optional): defines the strategy to draw
            samples from the dataset. Can be any ``Iterable`` with ``__len__``
            implemented. If specified, :attr:`shuffle` must not be specified.
        batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
            returns a batch of indices at a time. Mutually exclusive with
            :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
            and :attr:`drop_last`.
        num_workers (int, optional): how many subprocesses to use for data
            loading. ``0`` means that the data will be loaded in the main process.
            (default: ``0``)
        collate_fn (Callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s).  Used when using batched loading from a
            map-style dataset.
        pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
            into device/CUDA pinned memory before returning them.  If your data elements
            are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
            see the example below.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: ``False``)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: ``0``)
        worker_init_fn (Callable, optional): If not ``None``, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: ``None``)
        multiprocessing_context (str or multiprocessing.context.BaseContext, optional): If
            ``None``, the default `multiprocessing context`_ of your operating system will
            be used. (default: ``None``)
        generator (torch.Generator, optional): If not ``None``, this RNG will be used
            by RandomSampler to generate random indexes and multiprocessing to generate
            ``base_seed`` for workers. (default: ``None``)
        prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
            in advance by each worker. ``2`` means there will be a total of
            2 * num_workers batches prefetched across all workers. (default value depends
            on the set value for num_workers. If value of num_workers=0 default is ``None``.
            Otherwise, if value of ``num_workers > 0`` default is ``2``).
        persistent_workers (bool, optional): If ``True``, the data loader will not shut down
            the worker processes after a dataset has been consumed once. This allows to
            maintain the workers `Dataset` instances alive. (default: ``False``)
        pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
            ``True``.
"""
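A few of these arguments in action on a hypothetical 10-sample TensorDataset: drop_last, shuffle combined with a seeded generator for a reproducible order, and a custom collate_fn that returns plain Python lists instead of stacked tensors.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical dataset of 10 samples: x = 0..9, y = x % 3
xs = torch.arange(10)
ds = TensorDataset(xs, xs % 3)

# drop_last: 10 samples with batch_size=4 give batches of 4, 4, 2; drop_last=True drops the 2
print([len(xb) for xb, _ in DataLoader(ds, batch_size=4)])                  # [4, 4, 2]
print([len(xb) for xb, _ in DataLoader(ds, batch_size=4, drop_last=True)])  # [4, 4]

# shuffle + generator: the same seed reproduces the same shuffled order
def first_epoch_order(seed):
    g = torch.Generator().manual_seed(seed)
    loader = DataLoader(ds, batch_size=5, shuffle=True, generator=g)
    return torch.cat([xb for xb, _ in loader]).tolist()

print(first_epoch_order(42) == first_epoch_order(42))  # True

# collate_fn: receives the list of samples [(x, y), ...] that form one batch
def collate_as_lists(samples):
    batch_x, batch_y = zip(*samples)
    return list(batch_x), list(batch_y)

xb, yb = next(iter(DataLoader(ds, batch_size=3, collate_fn=collate_as_lists)))
print(type(xb), len(xb))  # <class 'list'> 3
```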

 

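num_workers is the argument that controls multiprocessing: with num_workers > 0, that many worker subprocesses load batches in parallel, while 0 (the default) loads the data in the main process.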

