1. torch.utils.data.Dataset
- torch.utils.data의 하위에 있는 Dataset은 Pytorch에서 데이터를 불러오고, 전처리하는 클래스입니다.
- 따라서 사용자가 자신의 데이터셋을 만들어 Pytorch에서 사용할 수 있도록 해줍니다.
- 이 클래스는 사용자 정의 데이터셋을 만들 수 있게 해주어 효율적인 데이터 로딩 및 전처리를 가능하게 합니다.
- 해당 클래스는 사용자가 직접 구현해야 합니다. torch.utils.data.Dataset을 상속받아 직접 구현합니다.
2. 파라미터
- __init__(self, ...) : 해당 클래스의 인스턴스를 초기화 합니다. 이 메서드에서는 데이터셋에서 필요한 인자를 받습니다. 예를 들어, 데이터셋의 경로, 이미지 크기, 데이터 전처리 방법 등이 있을 수 있습니다. 그 외에도 다양한 인자를 받아서 데이터셋에서 활용할 수 있는 augmentataion 옵션 등도 활용할 수 있습니다.
- __len__(self) : 이 메서드는 데이터셋의 크기를 반환합니다. 일반적으로 len(dataset) 형태로 호출됩니다. 혹은 sampler의 구현, DataLoader의 기본 옵션에 맞게 데이터셋의 크기를 반환합니다.
- __getitem__(self, index) : 인덱스를 입력받아 데이터셋에서 해당 인덱스를 반환합니다. dataset[index] 형태로 호출됩니다. 이 메서드는 특히 중요한데, 해당 메서드에서 데이터를 불러오고 전처리하는 작업이 이루어집니다. 이 때, 로드된 데이터는 일반적으로 Pytorch Tensor 형태로 반환됩니다.
이러한 메소드를 구현하여 사용자 정의 데이터셋 클래스를 만들 수 있습니다.
이후에는 torch.utils.data.DataLoader 클래스를 사용하여 데이터셋을 로드하고 이를 효율적으로 사용할 수 있습니다.
3. 예시
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
# 해당 클래스는 torch.utils.data.Dataset을 상속받습니다.
def __init__(self, data, target):
# 해당 클래스의 인스턴스를 초기화합니다.
self.data = data
self.target = target
def __len__(self):
# 해당 데이터셋의 총 데이터 수를 리턴합니다.
return len(self.data)
def __getitem__(self, idx):
# 인덱스 idx에 해당하는 데이터를 가져옵니다.
return self.data[idx], self.target[idx]
if __name__ == "__main__":
data = torch.randn(100, 3, 32, 32)
target = torch.randint(0, 10, (100, ))
mydataset = MyDataset(data, target)
myData, myTarget = mydataset[30]
print(len(mydataset))
print(myData.shape)
print(myTarget)
>>> 100
>>> torch.Size([3, 32, 32])
>>> tensor(3)
4. 끝
다음 글에서는 torch.utils.data.DataLoader 클래스에 대해 알아보겠습니다.
감사합니다.
지적 환영합니다.
'인공지능 > Pytorch' 카테고리의 다른 글
[Pytorch] torch.utils.data.DataLoader (0) | 2023.03.13 |
---|