tf.nn.separable_conv2d is slower than conv2d on GPU #12940

Closed · opened by @ghost

Description

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 16.04
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): TF 1.3
  • Python version: 3.6
  • Bazel version (if compiling from source):
  • CUDA/cuDNN version: CUDA 8.0 / cuDNN 6
  • GPU model and memory: GTX1080ti 11G

Describe the problem

In theory, separable_conv2d should be more efficient than conv2d, but when I test a simple model on CIFAR-10, the results show that tf.nn.separable_conv2d runs slower on the GPU, though it is indeed faster on the CPU.
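
For reference, here is a back-of-the-envelope count of multiply-accumulates per output pixel for the shapes used in the script below (3x3 kernels, 16 input and 16 output channels). This is only a sketch: it counts arithmetic and ignores memory traffic and kernel-launch overhead, which can dominate for small layers.

# MACs per output pixel (sketch: arithmetic only, no memory or launch costs)
k, c_in, c_out = 3, 16, 16
standard = k * k * c_in * c_out           # conv2d: 2304
separable = k * k * c_in + c_in * c_out   # depthwise + pointwise: 144 + 256 = 400
print(standard, separable, standard / separable)  # 2304 400 5.76

So on paper the separable layer needs roughly 5.8x fewer multiply-accumulates here.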

Here are my test results on the GPU:

training time for normal_conv after 2000 step: 8.18395892999979 sec
time for normal_conv after one forward step:  0.003980965999289765 sec
training time for separable_conv after 2000 step: 9.158266903999902 sec
time for separable_conv after one forward step:  0.0036441169995669043 sec

Source code / logs

Below is a fully self-contained example. I first define a model with two conv2d layers, then a second model with one conv2d followed by one separable_conv2d. Both models use 16 channels in each conv layer and identical fc layers.

import tensorflow as tf
import timeit
import numpy as np
from tensorflow.contrib.keras.python.keras.datasets.cifar10 import load_data

(x_train, y_train), (x_val, y_val) = load_data()
learning_rate = 0.001
num_steps = 2000
n_classes = 10
batch_size = 32

def reformat(labels):
    # One-hot encode: map 0 to [1.0, 0.0, ...], 1 to [0.0, 1.0, ...], etc.
    labels = (np.arange(n_classes) == labels[:, None]).astype(np.float32)
    return labels.reshape(labels.shape[0], n_classes)
train_labels = reformat(y_train)
tf.reset_default_graph()
x = tf.placeholder(tf.float32, [None, 32, 32, 3])
y = tf.placeholder(tf.float32, [None, 10])
weights1 = {}
weights2 = {}
dtype = tf.float32
with tf.name_scope('INIT_OP'):
    conv_initializer =  tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
    fc_initializer =  tf.contrib.layers.xavier_initializer(dtype=dtype)

k = 3
kernel = 16

# Define weights for normal ConvNet
with tf.name_scope('VARIABLES_1'):
    weights1['conv1'] = tf.get_variable('conv1', [k, k, 3, kernel], initializer=conv_initializer, dtype=dtype, trainable=True)
    weights1['b1'] = tf.get_variable('b1', initializer=tf.zeros([kernel]))
    weights1['conv2'] = tf.get_variable('conv2', [k, k, kernel, kernel], initializer=conv_initializer, dtype=dtype, trainable=True)
    weights1['b2'] = tf.get_variable('b2', initializer=tf.zeros([kernel]))

    weights1['wd1'] = tf.get_variable('wd1', [8*8*kernel, 512], initializer=fc_initializer, dtype=dtype, trainable=True)
    weights1['bd1'] = tf.get_variable('bd1',  initializer=tf.zeros([512]) )
    weights1['wd2'] = tf.get_variable('wd2', [512, 10], initializer=fc_initializer, dtype=dtype, trainable=True)
    weights1['bd2'] = tf.get_variable('bd2',  initializer=tf.zeros([10]) )


#Define weights for separable ConvNet
with tf.name_scope('VARIABLES_sep'):
    weights2['conv1'] = tf.get_variable('2_conv1', [k, k, 3, kernel], initializer=conv_initializer, dtype=dtype, trainable=True)
    weights2['conv_dw2'] = tf.get_variable('conv_dw2', [k, k, kernel, 1], initializer=conv_initializer, dtype=dtype, trainable=True)
    weights2['conv_pw2'] = tf.get_variable('conv_pw2', [1, 1, kernel, kernel], initializer=conv_initializer, dtype=dtype, trainable=True)

    weights2['b1'] = tf.get_variable('2_b1', initializer=tf.zeros([kernel]))
    weights2['b2'] = tf.get_variable('2_b2', initializer=tf.zeros([kernel]))

    weights2['wd1'] = tf.get_variable('2_wd1', [8*8*kernel, 512], initializer=fc_initializer, dtype=dtype, trainable=True)
    weights2['bd1'] = tf.get_variable('2_bd1',  initializer=tf.zeros([512]) )
    weights2['wd2'] = tf.get_variable('2_wd2', [512, 10], initializer=fc_initializer, dtype=dtype, trainable=True)
    weights2['bd2'] = tf.get_variable('2_bd2',  initializer=tf.zeros([10]) )

def forward_conv_sep(inp, weights):
    # Conv block, then a depthwise-separable conv block, then two fc layers.
    hidden = conv_block(inp, weights['conv1'], weights['b1'])
    hidden = maxpool2d(hidden)
    hidden = conv_block_dw(hidden, weights['conv_dw2'], weights['conv_pw2'], weights['b2'])
    hidden = maxpool2d(hidden)
    hidden = tf.reshape(hidden, [-1, np.prod([int(dim) for dim in hidden.get_shape()[1:]])])
    fc1 = tf.matmul(hidden, weights['wd1']) + weights['bd1']
    fc1 = tf.nn.relu(fc1)
    return tf.matmul(fc1, weights['wd2']) + weights['bd2']

def forward_conv(inp, weights):
    # Two standard conv blocks, then two fc layers.
    hidden = conv_block(inp, weights['conv1'], weights['b1'])
    hidden = maxpool2d(hidden)
    hidden = conv_block(hidden, weights['conv2'], weights['b2'])
    hidden = maxpool2d(hidden)
    hidden = tf.reshape(hidden, [-1, np.prod([int(dim) for dim in hidden.get_shape()[1:]])])
    fc1 = tf.matmul(hidden, weights['wd1']) + weights['bd1']
    fc1 = tf.nn.relu(fc1)
    return tf.matmul(fc1, weights['wd2']) + weights['bd2']


def conv_block_dw(inp, cweight_dw, cweight_pw, bweight, activation=tf.nn.relu):
    # Depthwise-separable convolution: k x k depthwise pass, then 1x1 pointwise pass.
    no_stride = [1, 1, 1, 1]
    conv_output = tf.nn.separable_conv2d(inp, cweight_dw, cweight_pw, no_stride, 'SAME') + bweight
    return activation(conv_output)

def conv_block(inp, cweight, bweight, activation=tf.nn.relu):
    # Standard k x k convolution followed by the given activation.
    no_stride = [1, 1, 1, 1]
    conv_output = tf.nn.conv2d(inp, cweight, no_stride, 'SAME') + bweight
    return activation(conv_output)

def maxpool2d(inp, k=2):
    return tf.nn.max_pool(inp, ksize=[1, k, k, 1], strides=[1, k, k, 1],
                          padding='SAME')

#logits for normal ConvNet
with tf.name_scope("forward_conv"):
    pred1 = forward_conv(x, weights1)

#Cost for normal ConvNet
with tf.name_scope("cost1"):
    loss1 = tf.nn.softmax_cross_entropy_with_logits(logits=pred1, labels=y)
    cost1 = tf.reduce_mean(loss1)

#training op for normal ConvNet
with tf.name_scope('train_op1'):
    train_op1 = tf.train.RMSPropOptimizer(learning_rate, 0.9).minimize(cost1)    

#logits for separable ConvNet
with tf.name_scope("forward_conv_sep"):
    pred2 = forward_conv_sep(x, weights2)

#Cost for separable ConvNet
with tf.name_scope("cost2"):
    loss2 = tf.nn.softmax_cross_entropy_with_logits(logits=pred2, labels=y)
    cost2 = tf.reduce_mean(loss2)

# training op for separable ConvNet
with tf.name_scope('train_op2'):
    train_op2 = tf.train.RMSPropOptimizer(learning_rate, 0.9).minimize(cost2)


with tf.name_scope('INIT'):
    init = tf.global_variables_initializer()


with tf.Session() as sess:

    sess.run(init)

    #train normal ConvNet for 2000 steps
    start = timeit.default_timer()
    for step in range(num_steps):
        r = np.random.choice(y_train.shape[0], batch_size, replace=False)
        batch_data = x_train[r]
        batch_labels = train_labels[r]

        feed_dict = {x : batch_data, y: batch_labels}
        _ , l = sess.run([train_op1,cost1], feed_dict=feed_dict)

    stop = timeit.default_timer()
    print('training time for normal_conv after ' + str(num_steps) + ' step:', stop - start)


    start = timeit.default_timer()
    feed_dict = {x : batch_data, y: batch_labels}
    predictions1 = sess.run(pred1, feed_dict=feed_dict)
    stop = timeit.default_timer()
    print('time for normal_conv after one forward step: ', stop - start)



    # train separable ConvNet for 2000 steps
    start = timeit.default_timer()
    for step in range(num_steps):
        r = np.random.choice(y_train.shape[0], batch_size, replace=False)
        batch_data = x_train[r]
        batch_labels = train_labels[r]

        feed_dict = {x : batch_data, y: batch_labels}
        _ , l = sess.run([train_op2,cost2], feed_dict=feed_dict)


    stop = timeit.default_timer()
    print('training time for separable_conv after ' + str(num_steps) + ' step:', stop - start)

    start = timeit.default_timer()
    feed_dict = {x : batch_data, y: batch_labels}
    predictions = sess.run(pred2, feed_dict=feed_dict)
    stop = timeit.default_timer()
    print('time for separable_conv after one forward step: ', stop - start)
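
One caveat about timings like these: each sess.run includes feed_dict copy overhead, and one-off costs (cuDNN autotuning, memory allocation) can skew short measurements. Below is a minimal sketch of a more isolated measurement; the time_op helper is hypothetical, not part of TensorFlow, and assumes the session and ops defined above.

# Hypothetical helper: warm up first, then average over many runs so that
# one-off costs (autotuning, allocation) do not dominate the measurement.
def time_op(sess, op, feed_dict, n_warmup=10, n_runs=100):
    for _ in range(n_warmup):
        sess.run(op, feed_dict=feed_dict)
    start = timeit.default_timer()
    for _ in range(n_runs):
        sess.run(op, feed_dict=feed_dict)
    return (timeit.default_timer() - start) / n_runs

# Usage (inside the Session block above):
#   print(time_op(sess, pred1, feed_dict))  # normal conv forward pass
#   print(time_op(sess, pred2, feed_dict))  # separable conv forward pass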

Activity

cy89 commented on Sep 14, 2017

@vrv my recollection is that there were some issues with backprop for separable convolutions on GPUs, which have been somewhat improved lately. Can you please comment on the state of the art?

vrv commented on Sep 14, 2017

I'm not aware of any specific issues; kernels get faster as important models need them! In general, convolutions of different sizes will have different performance characteristics, and it's possible our separable convolution implementations are slow for some combinations of shapes. I'm not sure whether that's the case here, but it could be. I also don't know whether theory matches practice here, since separable convolutions are less compute-dense than normal convolutions. I believe the benefit is that you get to use fewer parameters to express a larger-capacity convolution.

carlthome (Contributor) commented on Sep 14, 2017

At what kernel sizes will convolutions be computed via FFT instead of directly? Anyway, the speedup from a separable convolution is more noticeable for larger kernels, so for small kernels the overhead of doing two convolutions might outweigh the savings, especially in what I assume is a highly optimized setting like the 3x3 kernel (Winograd).

Essentially, for an [m, n] kernel it would take m*n calculations for a convolution and m+n calculations for the separable convolution, if I'm not mistaken.

ghost commented on Oct 5, 2017

@carlthome @vrv I do notice that when the number of filters gets larger (128 or 192, etc.), separable_conv2d is faster than conv2d. But in my case I applied separable_conv2d on CIFAR-10 with a small number of filters, and it is actually slower than conv2d on my GPU. What could be the cause?

carlthome (Contributor) commented on Oct 5, 2017

Same principle. For small kernel matrices, the overhead involved in doing two convolutions might be larger than the speedup.
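
For reference, tf.nn.separable_conv2d implements a depthwise-separable convolution (a k x k depthwise pass followed by a 1x1 pointwise pass), so per output pixel it costs roughly k*k*c_in + c_in*c_out multiply-accumulates versus k*k*c_in*c_out for a full convolution. A quick sketch of the theoretical ratio as the output-channel count grows (arithmetic only; the fixed cost of launching two kernels is ignored):

# Theoretical MAC ratio of conv2d over depthwise-separable conv2d, per pixel:
# k*k*c_in*c_out / (k*k*c_in + c_in*c_out) = k*k*c_out / (k*k + c_out)
k = 3
for c_out in [16, 32, 64, 128, 256]:
    print(c_out, round(k * k * c_out / (k * k + c_out), 2))
# -> 5.76, 7.02, 7.89, 8.41, 8.69 (approaches k*k = 9 as c_out grows)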

ghost commented on Oct 5, 2017

Understood, thanks for the explanation!

tensorflowbutler (Member) commented on Dec 20, 2017

It has been 14 days with no activity and the awaiting tensorflower label was assigned. Please update the label and/or status accordingly.

tensorflowbutler (Member) commented on Jan 3, 2018

It has been 14 days with no activity and the awaiting tensorflower label was assigned. Please update the label and/or status accordingly.

tensorflowbutler (Member) commented on Jan 18, 2018

Nagging Awaiting TensorFlower: It has been 14 days with no activity and the awaiting tensorflower label was assigned. Please update the label and/or status accordingly.

tensorflowbutler (Member) commented on Jan 23, 2018

A member of the TensorFlow organization has replied after the stat:awaiting tensorflower label was applied.

ghost closed this as completed on Jan 24, 2018

Niels-Sch commented on Dec 9, 2018

I am getting an immense slowdown as well. Is there a more efficient implementation known?

keunwoochoi commented on Aug 21, 2019

Same problem. @HouseOfFinwe Did you figure out how to fix this?

Niels-Sch commented on Aug 21, 2019

@keunwoochoi I did not, please let me know if either of you managed.

keunwoochoi commented on Aug 21, 2019

@HouseOfFinwe Ok thanks. tensorlayer/TensorLayer#416 (comment) says the speed issue was fixed in TF 1.5, but I'm having it with 1.13 now. Do you remember which versions you have tried?

Niels-Sch commented on Aug 21, 2019

Sadly, I do not.

Uchiha-Shisui commented on Oct 19, 2019

> @HouseOfFinwe Ok thanks. tensorlayer/tensorlayer#416 (comment) says the speed issue was fixed in tf 1.5, but I'm having it with 1.13 now. Do you remember which versions have you tried?

Still exists on 1.14.

liuheng92 commented on Dec 25, 2019

Why close??? This still happens with TF 1.13.

byronyi (Contributor) commented on Feb 10, 2020

This should be fixed by #33836. It currently requires tf-nightly but will ship in the coming 2.2 release.
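
For anyone who wants to verify this on their own setup, here is a minimal TF 2.x timing sketch; the shapes, iteration counts, and timing loop are illustrative assumptions, not taken from this issue or from #33836.

import time
import tensorflow as tf

x = tf.random.normal([32, 32, 32, 16])  # NHWC batch of CIFAR-sized feature maps
w = tf.random.normal([3, 3, 16, 16])    # standard conv filter
dw = tf.random.normal([3, 3, 16, 1])    # depthwise filter
pw = tf.random.normal([1, 1, 16, 16])   # pointwise (1x1) filter

@tf.function
def conv(inp):
    return tf.nn.conv2d(inp, w, strides=[1, 1, 1, 1], padding='SAME')

@tf.function
def sep(inp):
    return tf.nn.separable_conv2d(inp, dw, pw, strides=[1, 1, 1, 1], padding='SAME')

for name, fn in [('conv2d', conv), ('separable_conv2d', sep)]:
    fn(x)  # warm-up: triggers tracing and cuDNN autotuning
    start = time.perf_counter()
    for _ in range(100):
        out = fn(x)
    _ = out.numpy()  # block until the GPU has finished
    print(name, (time.perf_counter() - start) / 100)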
