PyTorch Dataset Caching: 학습 시간 최적화

PyTorch Logo

개요

이미지 I/O는 딥러닝 모델 학습에 있어 가장 큰 병목이 되는 지점 중 하나 입니다. 이 글에서는 이미지 데이터를 Caching 하여 학습 속도로를 최적화 하는 방법을 소개합니다.

PyTorch Dataset Caching

샘플 데이터셋 코드

class SampleDataset(VisionDataset):   
    def __init__(
        self, 
        root_dir:str, 
        shared_dict=None,
        use_caching=False,
    ) -> None:
        super(SampleDataset, self).__init__(root_dir)
        
        self.caching_data = shared_dict # 공유 dict
        self.use_caching = use_caching  # Caching 사용 여부
        self.image_paths = glob(f'{root_dir}/**/*.jpg', recursive=True)  #전체 이미지 파일 경로 리스트
        
    def __getitem__(self, index:int):        
        if index not in self.caching_data or not self.use_caching:
            self.caching_data[index] = cv2.imread(self.image_paths[index])

        return self.caching_data[index], 0

    def __len__(self) -> int:
        return len(self.image_paths)

테스트를 위해 Torchvision 패키지의 VisionDataset을 상속해 커스텀 데이터셋을 위와 같이 작성하였습니다.

세부 코드 설명

self.image_paths = glob(f'{root_dir}/**/*.jpg', recursive=True)  #전체 이미지 파일 경로 리스트

glob 패키지를 이용해 하위 폴더에 있는 이미지(*.jpg) 파일들의 경로를 리스트 형태로 얻습니다.

def __getitem__(self, index:int):        
    if index not in self.caching_data or not self.use_caching:
        self.caching_data[index] = cv2.imread(self.image_paths[index])

    return self.caching_data[index], 0

self.__getitem__ 함수는 인덱스를 이용해 데이터를 호출할 때 사용되는 기본 Python 메서드 입니다. 만약 인덱스 접근에 의해 위 함수가 호출 되었을 때 self.caching_data:dict 에 해당하는 index 번호 데이터가 들어 있지 않거나 Caching을 사용하지 않는 다면 opencv를 이용해 이미지를 디스크로부터 읽어 self.caching_data 에 넣어 줍니다. 반대로 해당하는 index의 데이터가 존재하고 Caching을 사용하는 상태라면 self.caching_data 의 해당하는 index 데이터를 반환합니다.

따라서, 이런 방식의 구현으로 인하여 두번째 Epoch 부터는 디스크에 있는 이미지가 아니라 메모리에 있는 데이터를 직접 사용함으로 이미지 I/O에 대한 병목이 줄어들게 됩니다.

테스트 및 결과

if __name__ == "__main__":
    
    BATCH_SIZE = 16
    manager = Manager()
    shared_dict = manager.dict()

    datasets = SampleDataset('/datasets', shared_dict=shared_dict, use_caching=True)
    dataset_loader = DataLoader(datasets, batch_size=BATCH_SIZE, num_workers=32)
    
    epochs = 10
    for epoch in range(epochs):
        i = 0
        sTime = time.time()
        for clips, target in dataset_loader:        
            print(f'[epoch {epoch+1}] {i}/{int(len(datasets)/BATCH_SIZE)}', end='\r')
            i += 1

        print(f'[epoch {epoch+1}] Done !! - {time.time() - sTime} sec')

위 코드를 이용하여 720×480 해상도 28,530 장의 이미지를 Load하는 테스트를 진행하였습니다.

image 9
Caching 사용 여부에 따른 속도 비교

위 테이블은 Caching 사용 여부에 따른 데이터 로드 속도를 비교한 것 입니다. 첫번째 Epoch에서는 두방법 모두 디스크로부터 데이터를 읽어 오기 때문에 속도 차이가 거의 없는 것을 확인 할 수 있습니다.

그러나 두번째 Epoch 부터는 Caching을 사용하였을 시에 데이터를 메모리로 부터 읽어오기 때문에 굉장히 큰 속도 차이가 나는 것을 확인 할 수 있었습니다. 매 Epoch 마다 약 20초씩 절약 한다면 300epcoh 학습시 약 1시간 40분 정도의 시간을 절약 할 수 있음으로 매우 큰 수치라고 할 수 있습니다.

신기한 점으로는 Caching을 미사용 했을 시에도 Epoch에 따라 데이터 로드 시간이 감소함을 확인할 수 있었습니다. 이는 계속해서 같은 위치의 데이터를 I/O 함에 따라 OS에서 이를 최적화 하는 것으로 추정할 수 있습니다(Disk I/O의 Hit Rate 증가).

메모리 관련 주의 사항

위 방법의 경우 첫번째 Epoch에서 모든 데이터를 메모리에 올리게 됨으로 메모리가 충분하지 않다면 사용할 수 없는 방법입니다(저는 이 테스트에서 약 41 GB의 메모리를 사용하는 것을 확인하였습니다).

다만 메모리가 적더라도 데이터셋의 일부만 이더라도 메모리에 올릴 수 있다면 그만큼 로드 시간을 절약할 수 있습니다. 따라서 psutil과 같은 패키지를 이용해 메모리의 사용량을 체크하여 메모리 사용량이 90% 이하일 때 까지만 데이터를 캐싱하는 방법을 사용할 수 있습니다.

PyTorch의 DataLoader는 num_workers의 숫자에 따라 프로세스를 생성하고 모든 Dataset을 그만큼 복사합니다(num_workers가 4라면 4개의 프로세스를 생성하고 프로세스마다 데이터셋을 복사). 그렇기 때문에, Python의 multiprocessing 패키지의 Manager.dict()을 사용하여 모든 프로세스가 같은 dictionary를 공유할 수 있도록 구현해야합니다. 만약 이렇게 구현하지 않는다면 로드 해야할 데이터가 20GB라면 그 만큼이 프로세스 마다 복사됩니다.

전체 코드

import time
from torch.utils.data import DataLoader
from torchvision.datasets import VisionDataset

from glob import glob

import cv2
from multiprocessing import Manager

class SampleDataset(VisionDataset):   
    def __init__(
        self, 
        root_dir:str, 
        shared_dict=None,
        use_caching=False,
    ) -> None:
        super(SampleDataset, self).__init__(root_dir)
        
        self.caching_datas = shared_dict
        self.use_caching = use_caching
        self.image_paths = glob(f'{root_dir}/**/*.jpg', recursive=True)

        
    def __getitem__(self, index:int):
        
        if index not in self.caching_datas or not self.use_caching:
            self.caching_datas[index] = cv2.imread(self.image_paths[index])

        return self.caching_datas[index], 0

    def __len__(self) -> int:
        return len(self.image_paths)


if __name__ == "__main__":
    
    BATCH_SIZE = 16
    manager = Manager()
    shared_dict = manager.dict()

    datasets = SampleDataset('/datasets', shared_dict=shared_dict, use_caching=True)
    dataset_loader = DataLoader(datasets, batch_size=BATCH_SIZE, num_workers=32)
    
    epochs = 10
    for epoch in range(epochs):
        i = 0
        sTime = time.time()
        for clips, target in dataset_loader:        
            print(f'[epoch {epoch+1}] {i}/{int(len(datasets)/BATCH_SIZE)}', end='\r')
            i += 1

        print(f'[epoch {epoch+1}] Done !! - {time.time() - sTime} sec')

댓글 달기

이메일 주소는 공개되지 않습니다. 필수 필드는 *로 표시됩니다

위로 스크롤