
Add a note in the docs about the momentum formulation used in optim #1099

keskarnitish (Contributor)

Description
I have been looking at the implementation of SGD + Momentum in PyTorch and noticed something a bit different from how other packages (and papers) describe it. For the moment, let's focus solely on (classical) momentum and not Nesterov's version.

At the time of writing, the implementation reads:

    if momentum != 0:
        param_state = self.state[p]
        if 'momentum_buffer' not in param_state:
            buf = param_state['momentum_buffer'] = d_p.clone()
        else:
            buf = param_state['momentum_buffer']
            buf.mul_(momentum).add_(1 - dampening, d_p)
        if nesterov:
            d_p = d_p.add(momentum, buf)
        else:
            d_p = buf

    p.data.add_(-group['lr'], d_p)

Mathematically, if we denote the momentum buffer by v and assume that dampening=0, at every iteration the buffer is updated as v = m*v + g and the step is ∆x = lr * v. Notice that the learning rate lr hits the momentum term v as well as the gradient. To me, this differs from the classical momentum update, and also from how other packages implement SGD+M.
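For concreteness, a minimal toy sketch of this update (plain Python, arbitrary values for m and lr, and a quadratic stand-in objective; not the actual optimizer code):

    m, lr = 0.9, 0.1      # momentum and learning rate (arbitrary values)
    p, v = 1.0, 0.0       # parameter and momentum buffer

    for _ in range(3):
        g = 2 * p         # gradient of the toy objective f(p) = p**2
        v = m * v + g     # buffer update: v = m*v + g
        p = p - lr * v    # step: ∆x = lr * v, so lr scales the whole buffer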

Let us contrast this with the Sutskever et al. paper and other commonly used packages such as Lasagne, Keras, Neon, etc.

Sutskever et al.

[Screenshot of the momentum update from Sutskever et al., "On the importance of initialization and momentum in deep learning"]

Retaining the syntax from above, the algorithm updates v as v = m*v - lr * g and takes the step ∆x = v. So, the learning rate lr only hits the gradient; it does not (explicitly) scale the accumulated momentum term, in contrast with PyTorch's implementation.
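The same toy sketch written in this formulation (same arbitrary m, lr, and objective as above):

    m, lr = 0.9, 0.1
    p, v = 1.0, 0.0

    for _ in range(3):
        g = 2 * p            # gradient of the toy objective f(p) = p**2
        v = m * v - lr * g   # buffer update: lr only scales the gradient
        p = p + v            # step: ∆x = v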

Lasagne

Lasagne employs the same rule as Sutskever et al. for its momentum update:

    for param in params:
        value = param.get_value(borrow=True)
        velocity = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                                 broadcastable=param.broadcastable)
        x = momentum * velocity + updates[param]
        updates[velocity] = x - param
        updates[param] = x

Keras

Same for Keras:

    for p, g, m in zip(params, grads, moments):
        v = self.momentum * m - lr * g  # velocity
        self.updates.append(K.update(m, v))

        if self.nesterov:
            new_p = p + self.momentum * v - lr * g
        else:
            new_p = p + v

Neon

and Neon:

                velocity[:] = self.momentum_coef * velocity - lrate * grad

                # Nesterov accelerated gradient (NAG) is implemented the same
                # as in torch's "sgd.lua". It's a reformulation of Sutskever's
                # NAG equation found in "On the importance of initialization
                # and momentum in deep learning".
                if self.nesterov:
                    param[:] = param + self.momentum_coef * velocity -\
                               lrate * grad
                else:
                    param[:] = param + velocity

Is this disparity real, or am I missing something important?

The difference between the two implementations is not insignificant, especially when lr is reduced during training. If my claim is true, maybe we could update the reference (I'm not sure what that would be), or include the above version in the SGD code (I can take this up if necessary)?

Activity

colesbury (Member) commented on Mar 25, 2017

For a fixed learning rate, the two formulations are equivalent. The Torch formulation is chosen because the step size is directly proportional to the learning rate. This means that if you decrease the learning rate, the step size decreases immediately rather than after some number of iterations, which is generally what you want.
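To make that concrete, a toy sketch (plain Python, arbitrary hyperparameters and a quadratic stand-in objective; not code from either library):

    def torch_step(p, v, g, m, lr):
        v = m * v + g        # PyTorch/Torch formulation
        return p - lr * v, v

    def sutskever_step(p, v, g, m, lr):
        v = m * v - lr * g   # Sutskever et al. formulation
        return p + v, v

    m, lr = 0.9, 0.1
    p1, v1 = 1.0, 0.0
    p2, v2 = 1.0, 0.0
    for _ in range(10):
        p1, v1 = torch_step(p1, v1, 2 * p1, m, lr)
        p2, v2 = sutskever_step(p2, v2, 2 * p2, m, lr)
        assert abs(p1 - p2) < 1e-9   # identical trajectories while lr is fixed
        # if lr were reduced mid-loop, the Torch step would shrink immediately,
        # while the Sutskever buffer would still carry the old lr for a while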

Title changed from "Implementation of SGD + Momentum" to "Add a note in the docs about the momentum formulation used in optim" on Mar 25, 2017
keskarnitish (Contributor, Author) commented on Mar 25, 2017

I agree. My only concern was that, given that the reference for the method is the Sutskever paper and there is no documentation to explain the difference, the current implementation could be a potential "gotcha" for folks moving to PyTorch from other frameworks.

soumith (Member) commented on Apr 5, 2017

@keskarnitish if you send a PR adding a note to the docs, I am happy to merge.

A commit referencing this issue was added on Apr 5, 2017.