Update moving_mean and moving_variance of BatchNormalization Layer use sess.run() #6752

Closed
@zzd1992

I usually use only the Keras layers and write the training code myself in TensorFlow.
I find that if I don't use model.fit to train a model, the moving_mean and moving_variance of the BatchNormalization layers are not updated. That is, moving_mean always stays at 0 and moving_variance always stays at 1.
Here is an example of my model:

import keras
import tensorflow as tf
import keras.layers as kl
import keras.backend as K
import numpy as np

K.set_learning_phase(1)
model = keras.models.Sequential()
model.add(kl.InputLayer([784]))
model.add(kl.Dense(400))
model.add(kl.normalization.BatchNormalization())
model.add(kl.Activation('relu'))
model.add(kl.Dense(400))
model.add(kl.normalization.BatchNormalization())
model.add(kl.Activation('relu'))
model.add(kl.Dense(10,activation='sigmoid'))

When I use model.fit to train it, moving_mean and moving_variance are updated.

model.compile(loss='categorical_crossentropy',optimizer=keras.optimizers.Adam())
model.fit(x, y, batch_size=500, epochs=1)

But when I train it with plain TensorFlow code like the following:

train = tf.train.AdamOptimizer(0.001).minimize(loss,var_list=model.weights)
_train, err = sess.run([train,loss],{img:a,label:b})

moving_mean and moving_variance are not updated.
I know the updates for moving_mean and moving_variance are listed in model.updates, but I don't know how to run them during training if I don't want to use model.fit.
Is there a simple solution?
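
For context, a minimal sketch of why the values stay at 0 and 1: the layer keeps an exponential moving average of the batch statistics, and that average only changes when its update ops are executed. The sketch below is plain Python rather than Keras code. The momentum of 0.99 matches the Keras default, while the helper name `moving_average_update` and the batch values are made up for illustration.

```python
# Sketch of the exponential moving-average update that a BatchNormalization
# layer applies once per training step. Plain Python, not the Keras source.
from statistics import mean, pvariance


def moving_average_update(moving, batch_stat, momentum=0.99):
    """Return the new moving statistic after one training step."""
    return momentum * moving + (1.0 - momentum) * batch_stat


# Initial values match what the question observes: mean 0, variance 1.
moving_mean, moving_variance = 0.0, 1.0

# Hypothetical activations of one feature over a batch (illustrative data).
batch = [2.0, 4.0, 6.0, 8.0]
batch_mean = mean(batch)            # 5.0
batch_variance = pvariance(batch)   # 5.0 (population variance, for simplicity)

# If the update is never run, the two moving values above never change.
# Running it once per step moves them toward the batch statistics.
for _ in range(100):
    moving_mean = moving_average_update(moving_mean, batch_mean)
    moving_variance = moving_average_update(moving_variance, batch_variance)
```

A training step that runs only the optimizer op skips this update entirely, which is consistent with the behavior described above.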

zzd1992 commented on May 25, 2017

I mean: how can I update moving_mean and moving_variance directly using sess.run()?

zzd1992 commented on May 26, 2017

zzd1992 changed the title from "Update moving_mean and moving_variance of BatchNormalization Layer without use of model.fit function" to "Update moving_mean and moving_variance of BatchNormalization Layer use sess.run()" on May 27, 2017

pkern90 commented on Aug 22, 2017

@zzd1992, you might wanna take a look at the section "Collecting trainable weights and state updates" of https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html.

Collecting trainable weights and state updates

Some Keras layers (stateful RNNs and BatchNormalization layers) have internal updates that need to be run as part of each training step. These are stored as a list of tensor tuples, layer.updates. You should generate assign ops for those, to be run at each training step. Here's an example:

from keras.layers import BatchNormalization

layer = BatchNormalization()(x)

update_ops = []
for old_value, new_value in layer.updates:
    update_ops.append(tf.assign(old_value, new_value))

Note that if you are using a Keras model (Model instance or Sequential instance), model.updates behaves in the same way (and collects the updates of all underlying layers in the model).

YoelShoshan commented on Oct 27, 2017

Actually, I think there's a small mistake in that tutorial, because "layer" there is just a TF tensor (the output of the layer), not the layer object itself.
(Or maybe the tutorial is just out of date.)

You need to change it into something like this:

from keras.layers import BatchNormalization

layer = BatchNormalization()

blah = layer(x)

update_ops = []
for old_value, new_value in layer.updates:
    update_ops.append(tf.assign(old_value, new_value))

Also, it seems that layer.updates already contains the assign ops, so a further change is needed:

from keras.layers import BatchNormalization

layer = BatchNormalization()

blah = layer(x)

update_ops = []
for assign_op in layer.updates:
    update_ops.append(assign_op)

please correct me if I'm wrong :)

alpapado commented on Jul 9, 2018

Is it possible to get a complete example of the above solution, to clarify when the update_ops should be run?

For instance, given zzd1992's code in the first post and the proposed solution, would a training step be run using

_train, err, _ = sess.run([train, loss, update_ops],{img:a,label:b})

or do the update_ops need to be run separately?
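
The thread does not include a confirmed answer, so the following is only a sketch of one common pattern, assuming (as YoelShoshan reports above) that model.updates already holds assign ops: fetch the update ops in the same sess.run() call as the optimizer step. The helper `make_train_fetches` is a hypothetical name introduced here, not part of Keras or TensorFlow, and the commented usage reuses the names `train`, `loss`, `img`, `label`, `a` and `b` from the first post.

```python
# Hypothetical helper (not a Keras or TensorFlow API): build a flat fetch
# list so the BatchNormalization update ops run in the same sess.run()
# call as the optimizer step.
def make_train_fetches(train_op, loss, update_ops):
    """Return [train_op, loss, *update_ops] for use as sess.run() fetches."""
    return [train_op, loss] + list(update_ops)


# Usage with the names from the first post (TensorFlow 1.x session API),
# shown as comments because it needs the graph and session built there:
#
#   fetches = make_train_fetches(train, loss, model.updates)
#   results = sess.run(fetches, {img: a, label: b})
#   err = results[1]  # results[1] is the loss value for this step
```

Another option under the same assumption would be to build the train op inside `with tf.control_dependencies(model.updates):`, so that a plain `sess.run([train, loss], ...)` triggers the updates as well.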

Participants: @fchollet, @YoelShoshan, @pkern90, @alpapado, @zzd1992

Issue #6752 · keras-team/keras