
DataLoader num_workers > 0 causes CPU memory from parent process to be replicated in all worker processes #13246

@bfreskura

Description

Editor note: There is a known workaround further down in this issue, which is to NOT use Python lists but to use something else instead, e.g., torch.tensor directly. See #13246 (comment). You can use a numpy array, but that only fixes the issue for the fork start method. See #13246 (comment) for more details.
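For illustration, here is a minimal sketch of that workaround applied to a dataset of image paths (the directory layout and the name PathsDataset are hypothetical, not from this issue): the paths are stored in a single fixed-width numpy string array rather than a Python list, so under the fork start method the workers read one contiguous, refcount-free buffer.

import os

import numpy as np
from torch.utils.data import Dataset


class PathsDataset(Dataset):
    def __init__(self, root="path/to/data"):
        paths = [
            os.path.join(root, cls, img)
            for cls in os.listdir(root)
            for img in os.listdir(os.path.join(root, cls))
        ]
        # One contiguous '<U...' array (a single object with a single refcount)
        # instead of millions of individually refcounted Python strings.
        self.paths = np.array(paths)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Decode only the requested item back into a Python str on demand.
        return str(self.paths[idx])

As the note above says, a numpy array only addresses the fork start method; see the linked comments for the tensor-based variant.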

🐛 Bug

CPU memory will leak if the DataLoader num_workers > 0.

To Reproduce

Run the following snippet:

from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
import os

class DataIter(Dataset):
    def __init__(self):
        path = "path/to/data"
        self.data = []

        for cls in os.listdir(path):
            for img in os.listdir(os.path.join(path, cls)):
                self.data.append(os.path.join(path, cls, img))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        with Image.open(self.data[idx]) as img:
            img = img.convert('RGB')
            return transforms.functional.to_tensor(img)


train_data = DataIter()
train_loader = DataLoader(train_data, batch_size=300,
                          shuffle=True,
                          drop_last=True,
                          pin_memory=False,
                          num_workers=18)

for i, item in enumerate(train_loader):
    if i % 200 == 0:
        print(i)

Observed behavior

CPU memory gradually increases, eventually filling up the whole RAM. E.g., the process starts at around 15 GB and grows to fill the whole 128 GB available on the system.
When num_workers=0, RAM usage is constant.

Environment

PyTorch version: 1.0.0.dev20181028
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 16.04.4 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609
CMake version: version 3.5.1

Python version: 3.5
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration: 
GPU 0: GeForce GTX 1080 Ti
GPU 1: GeForce GTX 1080 Ti
GPU 2: GeForce GTX 1080 Ti

Nvidia driver version: 390.67
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.1.4

Versions of relevant libraries:
[pip] Could not collect
[conda] Could not collect

PIL.__version__
'5.3.0'

Additional info

There are around 24 million images in the dataset and all image paths are loaded into a single list as presented in the above code snippet.

I have also tried multiple PyTorch versions (0.4.0 and 0.4.1) and the effect is the same.

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @ssnl @VitalyFedyunin @ejguan

Activity

ssnl (Collaborator) commented on Oct 29, 2018

Do you see memory usage increasing when iterating, or before you even start to iterate?

bfreskura (Author) commented on Oct 29, 2018

@ssnl During the iteration only.

ezyang (Contributor) commented on Oct 29, 2018

When we fix #13243 we should check if this one gets fixed too.

samgd commented on Oct 31, 2018

I've been experiencing something similar, where memory usage continuously climbs until an OOM is triggered when using a batch_sampler with num_workers > 0.

To Reproduce

import math

from torch.utils.data import DataLoader


class Sampler:
    def __init__(self, n=100000, batch_size=32):
        self.n = n
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(float(self.n)/self.batch_size)

    def __iter__(self):
        batch = []
        for i in range(self.n):
            batch.append(i)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if batch:
            yield batch

            
N = 100000000
train_data = list(range(N))

            
def ok():
    train_sampler = Sampler(len(train_data))
    train_loader = DataLoader(train_data,
                              num_workers=0,
                              batch_sampler=train_sampler)
    
    for i, item in enumerate(train_loader):
        if i % 10000 == 0:
            print(i)
            
            
def leaky():
    train_sampler = Sampler(len(train_data))
    train_loader = DataLoader(train_data,
                              num_workers=8,
                              batch_sampler=train_sampler)

    for i, item in enumerate(train_loader):
        if i % 10000 == 0:
            print(i)
            
            
print('Starting ok')
ok()
print('ok done, starting leaky()')
leaky()
print('leaky done')

Environment

$ python3 collect_env.py
Collecting environment information...
PyTorch version: 0.4.0
Is debug build: No
CUDA used to build PyTorch: 9.1.85

OS: Ubuntu 16.04.5 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609
CMake version: version 3.5.1

Python version: 3.5
Is CUDA available: Yes
CUDA runtime version: 9.1.85
GPU models and configuration: GPU 0: GeForce GTX 1050 Ti with Max-Q Design
Nvidia driver version: 390.77
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.1.2
/usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a

Versions of relevant libraries:
[pip] Could not collect
[conda] Could not collect

bfreskura (Author) commented on Nov 7, 2018

@ezyang

When we fix #13243 we should check if this one gets fixed too.

The issue is still present in 1.0.0.dev20181105, where #13243 is fixed.

bfreskura (Author) commented on Nov 7, 2018

After some more investigation, I have found an exact scenario when the leak occurs. Consider the code example below:

from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch


class DataIter(Dataset):
    def __init__(self):
        self.data_np = np.array([x for x in range(24000000)])
        self.data = [x for x in range(24000000)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        data = np.array([data], dtype=np.int64)
        return torch.tensor(data)


train_data = DataIter()
train_loader = DataLoader(train_data, batch_size=300,
                          shuffle=True,
                          drop_last=True,
                          pin_memory=False,
                          num_workers=18)

for i, item in enumerate(train_loader):
    if i % 1000 == 0:
        print(i)

If we use the self.data variable, which is a standard Python list of ints, the memory leak occurs. However, if the self.data_np variable is used, which holds the same data in the form of a NumPy array, the leak does not occur.
Another observation is that the leakage is significantly less severe with shuffle=False in the DataLoader.

svishnu88 commented on Nov 10, 2018

I face a similar issue, but in my case it occurs with a numpy array too. I am using Python 3.7 and a PyTorch nightly release.

mprostock commented on Dec 8, 2018

I don't know how multiprocessing really works under the hood of PyTorch, but we have extensively discussed this "memory leak" issue (which probably isn't a memory leak!) on the fast.ai forums (https://forums.fast.ai/t/runtimeerror-dataloader-worker-is-killed-by-signal/31277/55?u=marcmuc). Preliminary findings which hopefully add some insight here (if this does NOT apply, please comment!):

Python multiprocessing: There is no way to store arbitrary Python objects (even simple lists) in shared memory in Python without triggering copy-on-write behaviour, because refcounts are updated every time something reads from these objects. The refcount updates dirty the memory pages one by one, which is why consumption grows slowly. The processes (workers) end up having all/most of the memory copied over bit by bit, which is why we get the memory overflow problem. The best description of this behavior is here (SO).
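A toy illustration of this copy-on-access effect (hypothetical example, not from the thread): merely reading a list item from a forked worker bumps that object's refcount and therefore dirties the memory page holding the object header.

import sys

data = ["path/to/img_%d.jpg" % i for i in range(3)]

print(sys.getrefcount(data[0]))  # typically 2: the list's reference + getrefcount's temporary argument
x = data[0]                      # simply taking a reference to the item...
print(sys.getrefcount(data[0]))  # ...raises the count to 3; in a forked worker this write triggers copy-on-write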

Possible solution:
Using multiprocessing as now: for Python multiprocessing to work without these refcount effects, the objects have to be made "compatible with", and wrapped in, multiprocessing.Array before the process pool is created and the workers are forked. This supposedly ensures that the memory is really shared and that no copy-on-write happens. This explains how to do it for numpy arrays, and this explains the reasoning behind it again. Don't be confused by some false statements, even by the authors of those otherwise good answers, claiming that copy-on-write makes all of this unnecessary (which is not true). One comment also points to this:

“Just to note, on Python fork() actually means copy on access (because just accessing the object will change its ref-count).”

I am not familiar with the torch.multiprocessing drop-in replacement that I understand PyTorch uses, but I would assume it will also not be able to remove the core Python refcount issue.
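A minimal sketch of the multiprocessing.Array approach described above, assuming purely numeric metadata (the class name SharedArrayDataset and the sizes are illustrative): the buffer is allocated in shared memory before the DataLoader forks its workers, and a numpy view over it is used for indexing.

import ctypes
import multiprocessing as mp

import numpy as np
import torch
from torch.utils.data import Dataset


class SharedArrayDataset(Dataset):
    def __init__(self, n=1_000_000):
        # Allocate the backing buffer in shared memory *before* the workers are
        # forked, so reads in the workers never touch per-item refcounts.
        shared_base = mp.Array(ctypes.c_int64, n, lock=False)
        self.data = np.frombuffer(shared_base, dtype=np.int64)
        self.data[:] = np.arange(n)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Only the requested element is copied into the returned tensor.
        return torch.tensor([self.data[idx]])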

soumith (Member) commented on Dec 9, 2018

@mprostock torch.multiprocessing is simply Python multiprocessing with a custom pickler. The custom pickler, whenever it encounters a torch.tensor, will automatically move it to shared memory, and hence, at least for torch.tensor objects, no copy-on-write happens.
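Building on that, here is a minimal sketch of the tensor-based workaround from the editor note at the top (illustrative integer data rather than the original image paths): the metadata lives in a single torch.Tensor, and share_memory_() moves its storage to shared memory explicitly.

import torch
from torch.utils.data import Dataset, DataLoader


class TensorBackedDataset(Dataset):
    def __init__(self, n=24_000_000):
        # One tensor object with one contiguous storage, instead of a Python
        # list of 24 million individually refcounted ints.
        self.data = torch.arange(n, dtype=torch.int64)
        self.data.share_memory_()  # move the storage to shared memory explicitly

    def __len__(self):
        return self.data.numel()

    def __getitem__(self, idx):
        return self.data[idx:idx + 1]


loader = DataLoader(TensorBackedDataset(), batch_size=300, shuffle=True, num_workers=4)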

mprostock commented on Dec 10, 2018

Thanks for the explanation! I have experimented with @bfreskura 's reproduction example and I think I can now pinpoint the problem:

The reproduction example by bfreskura above showed the difference between a regular Python list and a numpy array. But the problem is not (only) the Python list itself; the same happens with a numpy array of dtype object. Python lists store only references to the objects; the objects themselves are kept separately in memory. Every object has a refcount, therefore every item in the list has a refcount.

Numpy arrays (of standard np types) are stored as contiguous blocks in memory and are only ONE object with one refcount.

This changes if you make the numpy array explicitly of dtype object, which makes it start behaving like a regular Python list (only storing references to (string) objects). The same "problems" with memory consumption now appear.

This would explain why, with regular lists (or numpy arrays of dtype object), we see the "memory leak", which is actually the copy-on-access problem of forked Python processes due to changing refcounts, not a memory leak.

So the problem probably (often) has nothing to do with tensors or actual torch objects, but rather with the lists of filenames and dicts of labels that are generally used within dataloaders/datasets.

I have created a notebook gist if someone wants to quickly try it.
Look at the memory consumption (a quick and dirty measurement of total system memory, so there are minor influences from other processes; I tried to keep the system clean).

[Figure: memory consumption in GB over time, with a fixed-length string array]

[Figure: memory consumption in GB over time, with an object array (only change!)]
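For reference, a small sketch of the comparison in that gist, using a synthetic list of paths rather than the actual notebook code; the only difference between the two arrays is the dtype.

import numpy as np

paths = ["/data/class_%d/img_%d.jpg" % (i % 10, i) for i in range(1000000)]

fixed_width = np.array(paths)                 # dtype like '<U28': one contiguous buffer, one refcount
object_array = np.array(paths, dtype=object)  # per-item Python str objects, each with its own refcount

print(fixed_width.dtype, object_array.dtype)

# In a forked worker, indexing object_array returns the stored str and bumps its
# refcount, dirtying pages of the parent's heap; indexing fixed_width only
# creates a fresh numpy scalar from the shared buffer.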

aurooj commented on Jan 15, 2019

I am facing the same issue. It fills up my RAM very fast if num_workers > 0.
I am deleting the variables which I feel are no longer needed in my code, and I also call gc.collect() on every iteration, but nothing helps.
Any workarounds?

NProkoptsev commented on Jan 18, 2019

Switching from dicts to pandas and from lists to numpy arrays helps me.

> I am facing the same issue. It fills up my RAM very fast if num_workers > 0.
> I am deleting the variables which I feel are no longer needed in my code, and I also call gc.collect() on every iteration, but nothing helps.
> Any workarounds?
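A sketch of the kind of switch suggested here, with a hypothetical filename-to-label dict: numeric labels end up in one contiguous buffer instead of millions of refcounted Python ints.

import numpy as np
import pandas as pd

# Hypothetical metadata: filename -> integer class label.
labels = {"img_%d.jpg" % i: i % 10 for i in range(1000000)}

label_array = np.fromiter(labels.values(), dtype=np.int64)  # contiguous int64 block
path_array = np.array(list(labels.keys()))                  # fixed-width '<U...' strings

# A DataFrame keeps the numeric column as a contiguous int64 block; note that
# string columns typically fall back to object dtype in pandas.
df = pd.DataFrame({"path": path_array, "label": label_array})
print(df.dtypes)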

[312 remaining comments not shown]


Metadata

Labels: high priority, module: dataloader, module: dependency bug, module: memory usage, module: molly-guard, module: multiprocessing, triaged
