PyTorch bug: rebuilt parameter indices size is not same as original model parameters size. 156 versus 22776 #47050

Description (opened by @zhiyuan1i)

🐛 Bug

To Reproduce

Steps to reproduce the behavior:

1. conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
2. Install the latest mmcv and mmdetection.
3. ./tools/dist_train.sh configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py 2
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
*****************************************
2020-10-29 21:24:04,114 - mmdet - INFO - Environment info:
------------------------------------------------------------
sys.platform: linux
Python: 3.8.5 (default, Sep  4 2020, 07:30:14) [GCC 7.3.0]
CUDA available: True
GPU 0: P106-100
GPU 1: GeForce GTX 1060 6GB
CUDA_HOME: /usr
NVCC: Build cuda_11.0_bu.TC445_37.28845127_0
GCC: gcc (Ubuntu 10.2.0-13ubuntu1) 10.2.0
PyTorch: 1.7.0
PyTorch compiling details: PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.2 Product Build 20200624 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v1.6.0 (Git Hash 5ef631a030a6f73131c77892041042805a06064f)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 10.2
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_37,code=compute_37
  - CuDNN 7.6.5
  - Magma 2.5.2
  - Build settings: BLAS=MKL, BUILD_TYPE=Release, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DUSE_VULKAN_WRAPPER -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, USE_CUDA=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, 

TorchVision: 0.8.1
OpenCV: 4.4.0
MMCV: 1.1.6
MMCV Compiler: GCC 10.2
MMCV CUDA Compiler: 11.0
MMDetection: 2.5.0+ac20ffe
------------------------------------------------------------

2020-10-29 21:24:04,456 - mmdet - INFO - Distributed training: True
2020-10-29 21:24:04,794 - mmdet - INFO - Config:
model = dict(
    type='FasterRCNN',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[0.0, 0.0, 0.0, 0.0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=2,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0.0, 0.0, 0.0, 0.0],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0))))
train_cfg = dict(
    rpn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.7,
            neg_iou_thr=0.3,
            min_pos_iou=0.3,
            match_low_quality=True,
            ignore_iof_thr=-1),
        sampler=dict(
            type='RandomSampler',
            num=256,
            pos_fraction=0.5,
            neg_pos_ub=-1,
            add_gt_as_proposals=False),
        allowed_border=-1,
        pos_weight=-1,
        debug=False),
    rpn_proposal=dict(
        nms_across_levels=False,
        nms_pre=2000,
        nms_post=1000,
        max_num=1000,
        nms_thr=0.7,
        min_bbox_size=0),
    rcnn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.5,
            neg_iou_thr=0.5,
            min_pos_iou=0.5,
            match_low_quality=False,
            ignore_iof_thr=-1),
        sampler=dict(
            type='RandomSampler',
            num=512,
            pos_fraction=0.25,
            neg_pos_ub=-1,
            add_gt_as_proposals=True),
        pos_weight=-1,
        debug=False))
test_cfg = dict(
    rpn=dict(
        nms_across_levels=False,
        nms_pre=1000,
        nms_post=1000,
        max_num=1000,
        nms_thr=0.7,
        min_bbox_size=0),
    rcnn=dict(
        score_thr=0.05,
        nms=dict(type='nms', iou_threshold=0.5),
        max_per_img=100))
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type='CocoDataset',
        ann_file='data/coco/annotations/instances_train2017.json',
        img_prefix='data/coco/train2017/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', with_bbox=True),
            dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
            dict(type='RandomFlip', flip_ratio=0.5),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='Pad', size_divisor=32),
            dict(type='DefaultFormatBundle'),
            dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
        ]),
    val=dict(
        type='CocoDataset',
        ann_file='data/coco/annotations/instances_val2017.json',
        img_prefix='data/coco/val2017/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiScaleFlipAug',
                img_scale=(1333, 800),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[123.675, 116.28, 103.53],
                        std=[58.395, 57.12, 57.375],
                        to_rgb=True),
                    dict(type='Pad', size_divisor=32),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img'])
                ])
        ]),
    test=dict(
        type='CocoDataset',
        ann_file='data/coco/annotations/instances_val2017.json',
        img_prefix='data/coco/val2017/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiScaleFlipAug',
                img_scale=(1333, 800),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[123.675, 116.28, 103.53],
                        std=[58.395, 57.12, 57.375],
                        to_rgb=True),
                    dict(type='Pad', size_divisor=32),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img'])
                ])
        ]))
evaluation = dict(interval=1, metric='bbox')
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[8, 11])
total_epochs = 12
checkpoint_config = dict(interval=1)
log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
work_dir = './work_dirs/faster_rcnn_r50_fpn_1x_coco'
gpu_ids = range(0, 1)

2020-10-29 21:24:05,218 - mmdet - INFO - load model from: torchvision://resnet50
2020-10-29 21:24:05,519 - mmdet - WARNING - The model and loaded state dict do not match exactly

unexpected key in source state_dict: fc.weight, fc.bias

loading annotations into memory...
Done (t=0.01s)
creating index...
index created!
loading annotations into memory...
Done (t=0.01s)
creating index...
index created!
loading annotations into memory...
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
Done (t=0.00s)
creating index...
index created!
2020-10-29 21:24:06,086 - mmdet - INFO - Start running, host: lizhiyuan@lizhiyuan-Server, work_dir: /mnt/Data/workspace/mmdetection/work_dirs/faster_rcnn_r50_fpn_1x_coco
2020-10-29 21:24:06,086 - mmdet - INFO - workflow: [('train', 1)], max: 12 epochs
2020-10-29 21:24:49,980 - mmdet - INFO - Epoch [1][50/146]	lr: 1.978e-03, eta: 0:24:49, time: 0.875, data_time: 0.125, memory: 3324, loss_rpn_cls: 0.3634, loss_rpn_bbox: 0.0203, loss_cls: 0.2062, acc: 95.9912, loss_bbox: 0.0038, loss: 0.5936
2020-10-29 21:25:28,628 - mmdet - INFO - Epoch [1][100/146]	lr: 3.976e-03, eta: 0:22:41, time: 0.773, data_time: 0.008, memory: 3325, loss_rpn_cls: 0.1206, loss_rpn_bbox: 0.0171, loss_cls: 0.0500, acc: 99.2227, loss_bbox: 0.0144, loss: 0.2021
2020-10-29 21:26:04,092 - mmdet - INFO - Saving checkpoint at 1 epochs
[                                                  ] 0/80, elapsed: 0s, ETA:Traceback (most recent call last):
  File "./tools/train.py", line 178, in <module>
    main()
  File "./tools/train.py", line 167, in main
    train_detector(
  File "/mnt/Data/workspace/mmdetection/mmdet/apis/train.py", line 150, in train_detector
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
  File "/home/lizhiyuan/anaconda3/envs/mmcv/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 125, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/home/lizhiyuan/anaconda3/envs/mmcv/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 54, in train
    self.call_hook('after_train_epoch')
  File "/home/lizhiyuan/anaconda3/envs/mmcv/lib/python3.8/site-packages/mmcv/runner/base_runner.py", line 307, in call_hook
    getattr(hook, fn_name)(self)
  File "/mnt/Data/workspace/mmdetection/mmdet/core/evaluation/eval_hooks.py", line 125, in after_train_epoch
    results = multi_gpu_test(
  File "/mnt/Data/workspace/mmdetection/mmdet/apis/test.py", line 97, in multi_gpu_test
    result = model(return_loss=False, rescale=True, **data)
  File "/home/lizhiyuan/anaconda3/envs/mmcv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/lizhiyuan/anaconda3/envs/mmcv/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 606, in forward
    if self.reducer._rebuild_buckets():
RuntimeError: replicas_[0].size() == rebuilt_param_indices_.size() INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1603729062494/work/torch/csrc/distributed/c10d/reducer.cpp":1326, please report a bug to PyTorch. rebuilt parameter indices size is not same as original model parameters size.156 versus 22776
Traceback (most recent call last):
  File "./tools/train.py", line 178, in <module>
    main()
  File "./tools/train.py", line 167, in main
    train_detector(
  File "/mnt/Data/workspace/mmdetection/mmdet/apis/train.py", line 150, in train_detector
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
  File "/home/lizhiyuan/anaconda3/envs/mmcv/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 125, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/home/lizhiyuan/anaconda3/envs/mmcv/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 54, in train
    self.call_hook('after_train_epoch')
  File "/home/lizhiyuan/anaconda3/envs/mmcv/lib/python3.8/site-packages/mmcv/runner/base_runner.py", line 307, in call_hook
    getattr(hook, fn_name)(self)
  File "/mnt/Data/workspace/mmdetection/mmdet/core/evaluation/eval_hooks.py", line 125, in after_train_epoch
    results = multi_gpu_test(
  File "/mnt/Data/workspace/mmdetection/mmdet/apis/test.py", line 97, in multi_gpu_test
    result = model(return_loss=False, rescale=True, **data)
  File "/home/lizhiyuan/anaconda3/envs/mmcv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/lizhiyuan/anaconda3/envs/mmcv/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 606, in forward
    if self.reducer._rebuild_buckets():
RuntimeError: replicas_[0].size() == rebuilt_param_indices_.size() INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1603729062494/work/torch/csrc/distributed/c10d/reducer.cpp":1326, please report a bug to PyTorch. rebuilt parameter indices size is not same as original model parameters size.156 versus 22776
Traceback (most recent call last):
  File "/home/lizhiyuan/anaconda3/envs/mmcv/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/lizhiyuan/anaconda3/envs/mmcv/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/lizhiyuan/anaconda3/envs/mmcv/lib/python3.8/site-packages/torch/distributed/launch.py", line 260, in <module>
    main()
  File "/home/lizhiyuan/anaconda3/envs/mmcv/lib/python3.8/site-packages/torch/distributed/launch.py", line 255, in main
    raise subprocess.CalledProcessError(returncode=process.returncode,
subprocess.CalledProcessError: Command '['/home/lizhiyuan/anaconda3/envs/mmcv/bin/python', '-u', './tools/train.py', '--local_rank=1', 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py', '--launcher', 'pytorch']' returned non-zero exit status 1.

Expected behavior

Distributed training should run through checkpoint saving and evaluation without the reducer assertion failure.

Environment

Output of the environment collection script (collect_env.py):

PyTorch version: 1.7.0
Is debug build: True
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.10 (x86_64)
GCC version: (Ubuntu 10.2.0-13ubuntu1) 10.2.0
Clang version: Could not collect
CMake version: version 3.16.3

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: P106-100
GPU 1: GeForce GTX 1060 6GB

Nvidia driver version: 450.80.02
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] torch==1.7.0
[pip3] torchaudio==0.7.0a0+ac17b64
[pip3] torchvision==0.8.1
[conda] blas 1.0 mkl defaults
[conda] cudatoolkit 10.2.89 hfd86e86_1 defaults
[conda] mkl 2020.2 256 defaults
[conda] mkl-service 2.3.0 py38he904b0f_0 defaults
[conda] mkl_fft 1.2.0 py38h23d657b_0 defaults
[conda] mkl_random 1.1.1 py38h0573a6f_0 defaults
[conda] numpy 1.19.2 py38h54aff64_0 defaults
[conda] numpy-base 1.19.2 py38hfa32c7d_0 defaults
[conda] pytorch 1.7.0 py3.8_cuda10.2.89_cudnn7.6.5_0 pytorch
[conda] torchaudio 0.7.0 py38 pytorch
[conda] torchvision 0.8.1 py38_cu102 pytorch

Additional context

cc @ezyang @gchanan @zou3519 @bdhirsh @heitorschueroff @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @xush6528 @osalpekar @jiayisuse @agolynski

Activity

mrshenli (Contributor) commented on Nov 2, 2020

Hey @uniartisan, have you checked whether the same model works with PyTorch 1.6? That would help us narrow down possible causes. Thanks.

@pritamdamania87 I wonder if we should revert gradient_as_bucket_view in 1.7.1.
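
For context, gradient_as_bucket_view is an opt-in DistributedDataParallel constructor argument introduced in 1.7 (default False). A minimal sketch of where the flag lives, assuming a single-process gloo group purely for illustration:

import os
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Stand up a one-process group so DDP can be constructed at all.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)

model = nn.Linear(8, 8)
# gradient_as_bucket_view defaults to False in PyTorch 1.7.
ddp_model = DDP(model, gradient_as_bucket_view=False)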

zhiyuan1i (Author) commented on Nov 2, 2020

> Hey @uniartisan, have you checked whether the same model works with PyTorch 1.6? That would help us narrow down possible causes. Thanks.
>
> @pritamdamania87 I wonder if we should revert gradient_as_bucket_view in 1.7.1.

I was using PyTorch 1.6 last week, and it seemed to be fine.
To test PyTorch 1.6 again, I took the following steps.

1. Downgrade PyTorch:

 conda install pytorch=1.6.0 torchvision torchaudio cudatoolkit=10.2 -c pytorch

 environment location: /home/lizhiyuan/anaconda3/envs/mmcv

  added / updated specs:
    - cudatoolkit=10.2
    - pytorch=1.6.0
    - torchaudio
    - torchvision


The following packages will be DOWNGRADED:

  pytorch              1.7.0-py3.8_cuda10.2.89_cudnn7.6.5_0 --> 1.6.0-py3.8_cuda10.2.89_cudnn7.6.5_0
  torchaudio                                     0.7.0-py38 --> 0.6.0-py38
  torchvision                              0.8.1-py38_cu102 --> 0.7.0-py38_cu102
2. Reinstall mmcv-full:

pip uninstall mmcv-full
pip install mmcv-full

3. Run my program; I get the following output:
(mmcv) :/mnt/Data/workspace/homework_deeplearning$ ./tools/dist_train.sh configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py  2
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
*****************************************
2020-11-03 00:17:29,852 - mmdet - INFO - Environment info:
------------------------------------------------------------
sys.platform: linux
Python: 3.8.5 (default, Sep  4 2020, 07:30:14) [GCC 7.3.0]
CUDA available: True
GPU 0: P106-100
GPU 1: GeForce GTX 1060 6GB
CUDA_HOME: /usr
NVCC: Build cuda_11.0_bu.TC445_37.28845127_0
GCC: gcc (Ubuntu 10.2.0-13ubuntu1) 10.2.0
PyTorch: 1.6.0
PyTorch compiling details: PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.2 Product Build 20200624 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v1.5.0 (Git Hash e2ac1fac44c5078ca927cb9b90e1b3066a0b2ed0)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 10.2
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_37,code=compute_37
  - CuDNN 7.6.5
  - Magma 2.5.2
  - Build settings: BLAS=MKL, BUILD_TYPE=Release, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DUSE_VULKAN_WRAPPER -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, USE_CUDA=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_STATIC_DISPATCH=OFF, 

TorchVision: 0.7.0
OpenCV: 4.4.0
MMCV: 1.1.6
MMCV Compiler: GCC 10.2
MMCV CUDA Compiler: 11.0
MMDetection: 2.5.0+e054208
------------------------------------------------------------

2020-11-03 00:17:30,200 - mmdet - INFO - Distributed training: True
2020-11-03 00:17:30,476 - mmdet - INFO - Config:
model = dict(
    type='FasterRCNN',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[0.0, 0.0, 0.0, 0.0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=2,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0.0, 0.0, 0.0, 0.0],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0))))
train_cfg = dict(
    rpn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.7,
            neg_iou_thr=0.3,
            min_pos_iou=0.3,
            match_low_quality=True,
            ignore_iof_thr=-1),
        sampler=dict(
            type='RandomSampler',
            num=256,
            pos_fraction=0.5,
            neg_pos_ub=-1,
            add_gt_as_proposals=False),
        allowed_border=-1,
        pos_weight=-1,
        debug=False),
    rpn_proposal=dict(
        nms_across_levels=False,
        nms_pre=2000,
        nms_post=1000,
        max_num=1000,
        nms_thr=0.7,
        min_bbox_size=0),
    rcnn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.5,
            neg_iou_thr=0.5,
            min_pos_iou=0.5,
            match_low_quality=False,
            ignore_iof_thr=-1),
        sampler=dict(
            type='RandomSampler',
            num=512,
            pos_fraction=0.25,
            neg_pos_ub=-1,
            add_gt_as_proposals=True),
        pos_weight=-1,
        debug=False))
test_cfg = dict(
    rpn=dict(
        nms_across_levels=False,
        nms_pre=1000,
        nms_post=1000,
        max_num=1000,
        nms_thr=0.7,
        min_bbox_size=0),
    rcnn=dict(
        score_thr=0.05,
        nms=dict(type='nms', iou_threshold=0.5),
        max_per_img=100))
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type='CocoDataset',
        ann_file='data/coco/annotations/instances_train2017.json',
        img_prefix='data/coco/train2017/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', with_bbox=True),
            dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
            dict(type='RandomFlip', flip_ratio=0.5),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='Pad', size_divisor=32),
            dict(type='DefaultFormatBundle'),
            dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
        ]),
    val=dict(
        type='CocoDataset',
        ann_file='data/coco/annotations/instances_val2017.json',
        img_prefix='data/coco/val2017/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiScaleFlipAug',
                img_scale=(1333, 800),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[123.675, 116.28, 103.53],
                        std=[58.395, 57.12, 57.375],
                        to_rgb=True),
                    dict(type='Pad', size_divisor=32),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img'])
                ])
        ]),
    test=dict(
        type='CocoDataset',
        ann_file='data/coco/annotations/instances_val2017.json',
        img_prefix='data/coco/val2017/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiScaleFlipAug',
                img_scale=(1333, 800),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[123.675, 116.28, 103.53],
                        std=[58.395, 57.12, 57.375],
                        to_rgb=True),
                    dict(type='Pad', size_divisor=32),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img'])
                ])
        ]))
evaluation = dict(interval=1, metric='bbox')
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[8, 11])
total_epochs = 12
checkpoint_config = dict(interval=1)
log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
work_dir = './work_dirs/faster_rcnn_r50_fpn_1x_coco'
gpu_ids = range(0, 1)

2020-11-03 00:17:30,935 - mmdet - INFO - load model from: torchvision://resnet50
2020-11-03 00:17:33,160 - mmdet - WARNING - The model and loaded state dict do not match exactly

unexpected key in source state_dict: fc.weight, fc.bias

loading annotations into memory...
Done (t=0.01s)
creating index...
index created!
loading annotations into memory...
Done (t=0.01s)
creating index...
index created!
loading annotations into memory...
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
Done (t=0.00s)
creating index...
index created!
2020-11-03 00:17:33,723 - mmdet - INFO - Start running, host: lizhiyuan@lizhiyuan-Server, work_dir: /mnt/Data/workspace/homework_deeplearning/work_dirs/faster_rcnn_r50_fpn_1x_coco
2020-11-03 00:17:33,723 - mmdet - INFO - workflow: [('train', 1)], max: 12 epochs
2020-11-03 00:18:19,676 - mmdet - INFO - Epoch [1][50/146]	lr: 1.978e-03, eta: 0:25:58, time: 0.916, data_time: 0.125, memory: 3323, loss_rpn_cls: 0.4061, loss_rpn_bbox: 0.0207, loss_cls: 0.1945, acc: 93.1475, loss_bbox: 0.0070, loss: 0.6282
2020-11-03 00:18:59,637 - mmdet - INFO - Epoch [1][100/146]	lr: 3.976e-03, eta: 0:23:36, time: 0.799, data_time: 0.009, memory: 3323, loss_rpn_cls: 0.1201, loss_rpn_bbox: 0.0168, loss_cls: 0.0453, acc: 99.2695, loss_bbox: 0.0123, loss: 0.1945
2020-11-03 00:19:36,161 - mmdet - INFO - Saving checkpoint at 1 epochs
[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 80/80, 7.5 task/s, elapsed: 11s, ETA:     0s

2020-11-03 00:19:49,968 - mmdet - INFO - Evaluating bbox...
Loading and preparing results...
DONE (t=0.00s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type *bbox*
DONE (t=0.03s).
Accumulating evaluation results...
DONE (t=0.01s).
Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.000
Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=1000 ] = 0.000
Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=1000 ] = 0.000
Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.000
Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = -1.000
Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = -1.000
Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.000
Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=300 ] = 0.000
Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=1000 ] = 0.000
Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.000
Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = -1.000
Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = -1.000
2020-11-03 00:19:50,014 - mmdet - INFO - Exp name: faster_rcnn_r50_fpn_1x_coco.py
2020-11-03 00:19:50,014 - mmdet - INFO - Epoch(val) [1][146]	bbox_mAP: 0.0000, bbox_mAP_50: 0.0000, bbox_mAP_75: 0.0000, bbox_mAP_s: 0.0000, bbox_mAP_m: -1.0000, bbox_mAP_l: -1.0000, bbox_mAP_copypaste: 0.000 0.000 0.000 0.000 -1.000 -1.000

It seems to be fine again, so I think changes in the new version caused this error.

rohan-varma (Member) commented on Nov 2, 2020

@mrshenli Shouldn't gradient_as_bucket_view be fully hidden behind a flag, so that there is no impact on DDP when the flag keeps its default value (False)?

mrshenli (Contributor) commented on Nov 2, 2020

@uniartisan thanks for the confirmation!

> Shouldn't gradient_as_bucket_view be fully hidden behind a flag, so that there is no impact on DDP when the flag keeps its default value (False)?

@rohan-varma Yep, it shouldn't. But my worry is that, since the change reorganizes some code in DDP/Reducer, it might lead to errors. This one looks relevant.

pritamdamania87 (Contributor) commented on Nov 2, 2020

> @pritamdamania87 I wonder if we should revert gradient_as_bucket_view in 1.7.1.

I'm looking into this issue and will post an update when I find the root cause.

pritamdamania87 (Contributor) commented on Nov 3, 2020

@uniartisan I tried to repro the problem with your instructions, but the job seems to be running fine on 1.7: https://gist.github.com/pritamdamania87/5bc75fa5176219c7afd76cf918793fd4. I'm wondering which COCO dataset you used here. I downloaded the ones from https://cocodataset.org/#download, and those seem to be much larger and take quite a bit of time to finish.

zhiyuan1i (Author) commented on Nov 3, 2020

> @uniartisan I tried to repro the problem with your instructions, but the job seems to be running fine on 1.7: https://gist.github.com/pritamdamania87/5bc75fa5176219c7afd76cf918793fd4. I'm wondering which COCO dataset you used here. I downloaded the ones from https://cocodataset.org/#download, and those seem to be much larger and take quite a bit of time to finish.

I did not download the official COCO dataset; I made a small copy myself. If you need it, I can share it with you.
My CUDA is installed directly through Ubuntu's apt (version 11), and PyTorch was installed following the official tutorial. The problem occurs specifically when saving each epoch's checkpoint.
Maybe my environment is unusual: my two graphics cards are different models, one GTX 1060 and one P106-100, but their chips and compute units are the same.
In addition, my mmdetection is the latest version cloned from GitHub.

I have uploaded my dataset and my modified mmdetection folder here:
https://drive.google.com/drive/folders/1s69B3GxBoeaZ6zA3fEm6T5Wybg4o21B1?usp=sharing
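
Since the two cards differ in name but share the same chip, a quick way to confirm they expose the same compute capability to PyTorch (a minimal sketch using standard torch.cuda calls):

import torch

for i in range(torch.cuda.device_count()):
    print(i, torch.cuda.get_device_name(i),
          torch.cuda.get_device_capability(i))  # both should print (6, 1)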

pritamdamania87 (Contributor) commented on Nov 3, 2020

> The problem occurs specifically when saving each epoch's checkpoint.

I see; this is probably why I couldn't reproduce it, since my runs never reached that point. I'll use the data you provided and try to repro the issue.

pritamdamania87 (Contributor) commented on Nov 4, 2020

@uniartisan Looks like this might be an issue with how mmcv is using DDP. MMDistributedDataParallel actually uses the reducer directly and mimics some of the DDP logic. This is really not supported, and frameworks shouldn't be using the reducer directly. The problem occurred because we moved some reducer code around in #44798, which broke MMDistributedDataParallel.

If you add this change to mmcv, the issue is resolved (although mmcv shouldn't be doing this in the first place):

diff --git a/mmcv/parallel/distributed.py b/mmcv/parallel/distributed.py
index 07771a3..60e2f31 100644
--- a/mmcv/parallel/distributed.py
+++ b/mmcv/parallel/distributed.py
@@ -28,6 +28,9 @@ class MMDistributedDataParallel(DistributedDataParallel):
         ``self.module.forward()`` with ``self.module.train_step()``.
         It is compatible with PyTorch 1.1 - 1.5.
         """
+        if self.reducer._rebuild_buckets():
+            print("Reducer buckets have been rebuilt in this iteration.")
+
         if getattr(self, 'require_forward_param_sync', True):
             self._sync_params()
         if self.device_ids:
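
Applied in context, the patched method would look roughly like the sketch below. This is assembled only from the diff and its visible context lines; the remainder of mmcv's train_step body is abbreviated with ellipses and is not mmcv's full implementation:

from torch.nn.parallel import DistributedDataParallel


class MMDistributedDataParallel(DistributedDataParallel):

    def train_step(self, *inputs, **kwargs):
        # The three added lines: mirror what DistributedDataParallel.forward()
        # does internally in PyTorch 1.7 by letting the reducer rebuild its
        # buckets before parameters are synced.
        if self.reducer._rebuild_buckets():
            print('Reducer buckets have been rebuilt in this iteration.')

        if getattr(self, 'require_forward_param_sync', True):
            self._sync_params()
        if self.device_ids:
            ...  # scatter inputs/kwargs and call self.module.train_step()
        ...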

zhiyuan1i (Author) commented on Nov 4, 2020

> @uniartisan Looks like this might be an issue with how mmcv is using DDP. MMDistributedDataParallel actually uses the reducer directly and mimics some of the DDP logic. This is really not supported, and frameworks shouldn't be using the reducer directly. The problem occurred because we moved some reducer code around in #44798, which broke MMDistributedDataParallel.
>
> If you add this change to mmcv, the issue is resolved (although mmcv shouldn't be doing this in the first place):
>
> diff --git a/mmcv/parallel/distributed.py b/mmcv/parallel/distributed.py
> index 07771a3..60e2f31 100644
> --- a/mmcv/parallel/distributed.py
> +++ b/mmcv/parallel/distributed.py
> @@ -28,6 +28,9 @@ class MMDistributedDataParallel(DistributedDataParallel):
>          ``self.module.forward()`` with ``self.module.train_step()``.
>          It is compatible with PyTorch 1.1 - 1.5.
>          """
> +        if self.reducer._rebuild_buckets():
> +            print("Reducer buckets have been rebuilt in this iteration.")
> +
>          if getattr(self, 'require_forward_param_sync', True):
>              self._sync_params()
>          if self.device_ids:

Thank you for your hard work. Should I go and report this issue to mmcv now?

pritamdamania87 (Contributor) commented on Nov 4, 2020

> Thank you for your hard work. Should I go and report this issue to mmcv now?

Yes, I think we should file an issue on mmcv and reference this issue.

zhiyuan1i (Author) commented on Nov 5, 2020

> > Thank you for your hard work. Should I go and report this issue to mmcv now?
>
> Yes, I think we should file an issue on mmcv and reference this issue.

Thanks a lot :)

> Hi @uniartisan,
> Thanks for your bug report. We will fix the MMDDP to catch up with PyTorch 1.7. As a workaround, we suggest you use PyTorch 1.6 for now.

Originally posted by @ZwwWayne in open-mmlab/mmcv#636 (comment)

I will close this issue later.
Thank you very much for your outstanding work :)

tchaton commented on Feb 5, 2021

> @uniartisan Looks like this might be an issue with how mmcv is using DDP. MMDistributedDataParallel actually uses the reducer directly and mimics some of the DDP logic. This is really not supported, and frameworks shouldn't be using the reducer directly. The problem occurred because we moved some reducer code around in #44798, which broke MMDistributedDataParallel.
>
> If you add this change to mmcv, the issue is resolved (although mmcv shouldn't be doing this in the first place):
>
> diff --git a/mmcv/parallel/distributed.py b/mmcv/parallel/distributed.py
> index 07771a3..60e2f31 100644
> --- a/mmcv/parallel/distributed.py
> +++ b/mmcv/parallel/distributed.py
> @@ -28,6 +28,9 @@ class MMDistributedDataParallel(DistributedDataParallel):
>          ``self.module.forward()`` with ``self.module.train_step()``.
>          It is compatible with PyTorch 1.1 - 1.5.
>          """
> +        if self.reducer._rebuild_buckets():
> +            print("Reducer buckets have been rebuilt in this iteration.")
> +
>          if getattr(self, 'require_forward_param_sync', True):
>              self._sync_params()
>          if self.device_ids:

Hey there, I kind of disagree. We should be able to use the reducer to make DDP more flexible; see manual optimization in PyTorch Lightning.
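
For contrast, the pattern DDP officially supports keeps all reducer bookkeeping inside DistributedDataParallel's own forward(). A minimal sketch, again assuming a single-process gloo group purely for illustration:

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29501')
dist.init_process_group('gloo', rank=0, world_size=1)

model = DDP(nn.Linear(4, 2))
out = model(torch.randn(3, 4))  # forward() handles bucket rebuilding itself
out.sum().backward()            # the reducer then synchronizes gradients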
