
keras.layers.BatchNormalization update_ops not added #19643

Closed
Opened by @jackd (Contributor)

Description

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 16.04
  • TensorFlow installed from (source or binary): source
  • TensorFlow version: v1.8.0-1660-ga543d94
  • Python version: 2.7.12
  • Bazel version (if compiling from source): 0.13.0
  • GCC/Compiler version (if compiling from source): 4.9.3
  • CUDA/cuDNN version: 9.1 / 7.1
  • GPU model and memory: GeForce GTX 1070, 8119 MiB
  • Exact command to reproduce: See source code below

Describe the problem

tf.keras.layers.BatchNormalization() does not add any operations to UPDATE_OPS under certain version constraints.

Source code / logs

#!/usr/bin/python
import tensorflow as tf

graph = tf.get_default_graph()
tf.keras.backend.set_learning_phase(True)
features = tf.zeros(shape=(3, 10), dtype=tf.float32)
normed = tf.keras.layers.BatchNormalization()(features)

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
print('n_ops:        %d' % len(graph.get_operations()))
print('n_update_ops: %d' % len(update_ops))

Output using v1.8.0-1660-ga543d94, CUDA 9.1, cuDNN 7.1, installed from source (other details as above):

n_ops:        51
n_update_ops: 0

Output using v1.8.0-0-g93bc2e2072, CUDA 9.0, cuDNN 7.0, pip installed:

n_ops:        41
n_update_ops: 2

Activity

jackd (Contributor, Author) commented on May 30, 2018

Same behaviour with a freshly pulled master build, tf.GIT_VERSION == v1.8.0-1660-ga543d94 (51 ops, 0 update ops).

bhack (Contributor) commented on May 31, 2018

Have you seen #16102?

jackd (Contributor, Author) commented on May 31, 2018

Sounds similar, but I'm unsure if it's related. That one is based on tf 1.4.1, and given that things work as expected in some versions of 1.8 (and, from what I remember, 1.7), I'd be surprised if they were the same. Is the keras implementation shipped with tensorflow versioned differently somehow? i.e. could we be using the same keras implementation under the hood despite different tf versions? Or has there been a recent rollback of certain aspects of the code base that could undo a bug fix?

Just confirmed that it's unlikely to have anything to do with CUDA/cuDNN (I don't see how it could, but... you never know). Same behaviour (51 ops, 0 update ops) with 9.0 / 7.0 / v1.8.0-2594-g25b2f01.

fchollet (Contributor) commented on Jun 4, 2018

That is by design. Global collections (and global states in general) are prone to various issues (like most global variables in software engineering) and should be avoided.

You can retrieve the updates created by your BatchNormalization layer via:

updates = layer.updates

Additionally you can retrieve the aggregated updates of all stateful layers in a Keras model via:

updates = model.updates

If you are writing your own custom training loops, you will have to run these updates as part of the call to sess.run(). But do note that tf.keras provides a built-in training loop in the form of model.fit(), as well as primitive methods for building your own custom training loops (in particular model.train_on_batch() and model.predict_on_batch()). If you are using Keras you should generally never have to manually manage your model's updates.
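
For instance, a minimal sketch of such a custom loop (TF 1.x graph mode; the stand-in data and loss are illustrative, not from the thread):

import tensorflow as tf

tf.keras.backend.set_learning_phase(True)
features = tf.zeros(shape=(3, 10), dtype=tf.float32)  # stand-in inputs
targets = tf.zeros(shape=(3, 10), dtype=tf.float32)   # stand-in labels

layer = tf.keras.layers.BatchNormalization()
normed = layer(features, training=True)
loss = tf.reduce_mean(tf.squared_difference(normed, targets))
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Run the moving-mean/variance updates together with the train step.
    sess.run([train_op] + layer.updates)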

jackd (Contributor, Author) commented on Jun 5, 2018

... ... ... you're obviously entitled to make design choices such as this, but for what it's worth: I think this is a terrible decision. I'm no fan of the collections system used by tensorflow, but changing it in this subpackage is:

  • inconsistent with the rest of tensorflow;
  • inconsistent with previous tf.keras versions - even within 1.8.0; and
  • inconsistent with other usages of collections by tf.keras.layers (see code below)
import tensorflow as tf
x = tf.zeros(shape=(2, 3), dtype=tf.float32)
y = tf.keras.layers.Dense(4)(x)
print('variables: %d' % len(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))

I can appreciate as a separate package it may be convenient to force people to do things only 'the keras way', and I'd be less opposed if this issue was in the main keras repository. However, keras exists within tensorflow as well as external to it, so I'd expect it to follow its conventions. This change is breaking for anyone who uses non-keras session management (e.g. estimators) - myself included. I'd love to be able to mix and match tensorflow and keras - mostly to use other people's code, and take advantage of the excellent tf.keras.applications - but decisions like this make me more and more inclined to purge any reference to it, since there seems to be no guarantee about forward compatibility.

bhack (Contributor) commented on Jun 5, 2018

@jackd You mentioned estimators. Do you think it is also involved in #17950?

jackd (Contributor, Author) commented on Jun 5, 2018

If this is a conscious design choice, as mentioned above, I highly doubt the estimator framework will cause the update ops to run, so this should probably be handled by keras's own mechanism during conversion. As far as I can tell, it isn't.

I'm not overly familiar with the keras code base, but I can't find any explicit reference to model.updates, and in the place where I suspect they're meant to be collected, they just aren't.

from keras/engine/training.py, Model._make_train_function:

        with K.name_scope(self.optimizer.__class__.__name__):
          # Training updates
          updates = self.optimizer.get_updates(
              params=self._collected_trainable_weights, loss=self.total_loss)
        # Unconditional updates
        updates += self.get_updates_for(None)
        # Conditional updates relevant to this model
        updates += self.get_updates_for(self._feed_inputs)
        # Stateful metrics updates
        updates += self.metrics_updates
        # Gets loss and metrics. Updates weights at each call.

It's not an unconditional update (see below), and I doubt it's a metric update or an update required for feed inputs. They're just not added to update_ops.

import tensorflow as tf

tf.keras.backend.set_learning_phase(True)
input_shape = (3,)
inp = tf.keras.Input(shape=input_shape, dtype=tf.float32)
x = tf.keras.layers.Dense(4, input_shape=input_shape)(inp)
x = tf.keras.layers.BatchNormalization()(x, training=True)

model = tf.keras.Model(inp, x)
model.compile(tf.train.AdamOptimizer(1e-3), 'mean_squared_error')
print('model_updates: %d' % len(model.updates))
for update in model.updates:
    print(update.name, update._unconditional_update)  # False

estimator = tf.keras.estimator.model_to_estimator(model)

z = tf.zeros((2, 3), dtype=tf.float32)
labels = tf.zeros((2, 4), dtype=tf.float32)

spec = estimator.model_fn(z, labels, mode='train', config=None)

print(len(tf.get_collection(tf.GraphKeys.UPDATE_OPS)))  # 0
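
A possible stopgap for custom loops, sketched under the assumption that re-populating the collection by hand is acceptable (the loss below is a placeholder, not part of the original report): copy the Keras-tracked updates back into UPDATE_OPS before building the train op.

# Workaround sketch: mirror the updates Keras tracks on the model back
# into the global UPDATE_OPS collection, so code following the old
# tf.GraphKeys.UPDATE_OPS convention still sees them.
for update in model.updates:
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update)

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
print(len(update_ops))  # should now be 2 (moving mean and moving variance)

loss = tf.reduce_mean(tf.square(model.output))  # placeholder loss for this sketch
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)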

bhack (Contributor) commented on Jun 12, 2018

@jackd Yes, I have the same suspicion that it is related to the converter API not handling update_ops correctly. This issue is now closed, but #17950 is still open. I've also added a reference there to the upstream keras-team/keras#9214. If you want, subscribe to the #17950 thread.

tanzhenyu (Contributor) commented on Jul 3, 2018

@jackd The update ops should be created through conditional updates from _feed_inputs (correct me if I'm wrong).
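
A quick way to check that claim (a sketch, reusing the model from jackd's example above; model.inputs is the public counterpart of the private _feed_inputs):

# Conditional updates are keyed by the inputs they depend on:
print(model.get_updates_for(model.inputs))  # expect the two BN update ops
print(model.get_updates_for(None))          # unconditional updates; expect []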

jchia (Contributor) commented on Nov 20, 2018

You can retrieve the updates created by your BatchNormalization layer via:

updates = layer.updates

@fchollet Design philosophy aside, I'm not sure how this works. Under 1.12, the following code prints an empty list for the first print and raises an AttributeError on the second print, because a Tensor does not have an updates attribute. According to your prescribed approach, I should be getting a non-empty list of update ops from bn.updates. How's it supposed to work? How should the code be modified?

#!/usr/bin/env python3

import tensorflow as tf
import tensorflow.keras.layers as kl

def main():
    g = tf.Graph()
    with g.as_default():
        layer = tf.placeholder(name='x', shape=(4, 4), dtype=tf.float32)
        bn = kl.BatchNormalization()
        print(bn.updates)  # prints [] because the layer has not been called yet
        layer = bn(layer, training=True)
        print(layer.updates)  # AttributeError: 'Tensor' object has no attribute 'updates'

main()

jackd (Contributor, Author) commented on Nov 20, 2018

@jchia change print(layer.updates) to print(bn.updates) :)
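
For completeness, a sketch of the snippet with that one-line fix applied (the updates are tracked on the layer object; the call returns a plain Tensor):

import tensorflow as tf
import tensorflow.keras.layers as kl

g = tf.Graph()
with g.as_default():
    x = tf.placeholder(name='x', shape=(4, 4), dtype=tf.float32)
    bn = kl.BatchNormalization()
    out = bn(x, training=True)  # `out` is a Tensor, which has no .updates
    print(bn.updates)           # the moving mean/variance update ops live on the layer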

jchia (Contributor) commented on Nov 20, 2018

@jchia change print(layer.updates) to print(bn.updates) :)

Thanks, that works.
