인공지능/Pytorch

[Pytorch] torch.utils.data.Dataset

바보1 2023. 2. 25. 16:24

1. torch.utils.data.Dataset

 

 

  • torch.utils.data의 하위에 있는 Dataset은 Pytorch에서 데이터를 불러오고, 전처리하는 클래스입니다.
  • 따라서 사용자가 자신의 데이터셋을 만들어 Pytorch에서 사용할 수 있도록 해줍니다.
  • 이 클래스는 사용자 정의 데이터셋을 만들 수 있게 해주어 효율적인 데이터 로딩 및 전처리를 가능하게 합니다.
  • 해당 클래스는 사용자가 직접 구현해야 합니다. torch.utils.data.Dataset을 상속받아 직접 구현합니다.

2. 파라미터

 

 

  • __init__(self, ...) : 해당 클래스의 인스턴스를 초기화 합니다. 이 메서드에서는 데이터셋에서 필요한 인자를 받습니다. 예를 들어, 데이터셋의 경로, 이미지 크기, 데이터 전처리 방법 등이 있을 수 있습니다. 그 외에도 다양한 인자를 받아서 데이터셋에서 활용할 수 있는 augmentataion 옵션 등도 활용할 수 있습니다.
  • __len__(self) : 이 메서드는 데이터셋의 크기를 반환합니다. 일반적으로 len(dataset) 형태로 호출됩니다. 혹은 sampler의 구현, DataLoader의 기본 옵션에 맞게 데이터셋의 크기를 반환합니다.
  • __getitem__(self, index) :  인덱스를 입력받아 데이터셋에서 해당 인덱스를 반환합니다. dataset[index] 형태로 호출됩니다. 이 메서드는 특히 중요한데, 해당 메서드에서 데이터를 불러오고 전처리하는 작업이 이루어집니다. 이 때, 로드된 데이터는 일반적으로 Pytorch Tensor 형태로 반환됩니다.

이러한 메소드를 구현하여 사용자 정의 데이터셋 클래스를 만들 수 있습니다.

이후에는 torch.utils.data.DataLoader 클래스를 사용하여 데이터셋을 로드하고 이를 효율적으로 사용할 수 있습니다.


3. 예시

 

 

import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):
    # 해당 클래스는 torch.utils.data.Dataset을 상속받습니다.
    def __init__(self, data, target):
        # 해당 클래스의 인스턴스를 초기화합니다.
        self.data = data
        self.target = target

    def __len__(self):
        # 해당 데이터셋의 총 데이터 수를 리턴합니다.
        return len(self.data)
    
    def __getitem__(self, idx):
        # 인덱스 idx에 해당하는 데이터를 가져옵니다.
        return self.data[idx], self.target[idx]
    

if __name__ == "__main__":
    data = torch.randn(100, 3, 32, 32)
    target = torch.randint(0, 10, (100, ))
    mydataset = MyDataset(data, target)

    myData, myTarget = mydataset[30]

    print(len(mydataset))
    print(myData.shape)
    print(myTarget)
>>> 100
>>> torch.Size([3, 32, 32])
>>> tensor(3)

4. 끝

 

 

다음 글에서는 torch.utils.data.DataLoader 클래스에 대해 알아보겠습니다.

감사합니다.

지적 환영합니다.

'인공지능 > Pytorch' 카테고리의 다른 글

[Pytorch] torch.utils.data.DataLoader  (0) 2023.03.13