Autograd refactor #1016

Merged
merged 8 commits from the autograd branch on May 1, 2017
Conversation

apaszke
Contributor

@apaszke apaszke commented Mar 16, 2017

Progress:

  • Change names in the backend to better match the new graph representation
  • Remove the Function superclass from Variable. Replace them with special nodes in the graph
  • Implement the new function definition format.
    • Make sure it's compatible with legacy definitions.
  • Make sure Variables are never unpacked in the Engine (partially done in the first commit)
  • Implement some of the built-in Functions using the new format (to see if it's actually usable)
  • Add a function that returns a list of higher-order grads

The tasks below are left for future PRs. The first two can be easily parallelized and can be done by others too.

  • Adapt all definitions of built-in Functions (using @once_differentiable for now).
  • Implement all backward functions using Variables (so they can be differentiated multiple times)
  • Add deprecation warnings for the old Function format
  • Add a switch for saving .creator graph traces

Summary of changes

New function definition format

Note that old-style declarations are still supported; most of the core implementations haven't been converted yet.

The new format makes it possible to implement vector-Jacobian products (vjp, L-op) of functions that depend on vjps of other functions (a.k.a. grad of grad, Hessian-vector products).

The new declarations look like this:

class MultiplyAdd(Function):
                                                            # 1.
    @staticmethod
    def forward(ctx, input1, scalar, input2):               # 2.
        ctx.scalar = scalar                                 # 3.
        return input1 + scalar * input2

    @staticmethod
    def backward(ctx, grad_output):                         # 4.
        return grad_output, None, ctx.scalar * grad_output  # 5.

Annotations (a short usage sketch follows the list):

  1. Functions can no longer have an __init__ method. Think of them as pairs of pure functions: formulas specifying how to compute the function and its vjp (the product of grad_output with the Jacobian of the function).
  2. Because of 1., forward can now accept arguments of arbitrary types (it used to accept only Variables). Any Variables appearing in args will be unpacked into Tensors. Arguments are not searched recursively: for example, a list of Variables won't be unpacked into a list of Tensors, and its elements won't be registered as inputs in the graph. Keyword arguments are not supported (argument ordering is needed to construct the graph).
  3. forward receives a ctx as its first argument. This is an object (of unspecified type, not an instance of this class) with an interface identical to self in old-style definitions (save_for_backward, mark_non_differentiable, etc.), and it is used to pass information to the backward call. For example, this function needs to save its scalar argument. Note that you shouldn't assign input or output tensors to it; intermediate buffers are fine, however.
  4. grad_output is now a Variable, and the whole backward method needs to be implemented in terms of Variables (they shouldn't be unpacked into tensors, or the derivative graph will be malformed; see the notes on @once_differentiable below). ctx will be the same object that was passed to forward.
  5. backward should return gradients for all arguments given to forward (including non-Variable arguments, for which the gradient should be None). Unnecessary trailing Nones are still accepted (useful when forward has optional arguments).
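For illustration, here's a minimal usage sketch of the definition above. It assumes that new-style Functions are invoked through their apply method rather than being instantiated (variable names are illustrative):

import torch
from torch.autograd import Variable

x = Variable(torch.randn(3), requires_grad=True)
y = Variable(torch.randn(3), requires_grad=True)

out = MultiplyAdd.apply(x, 2.0, y)  # no instance is created; 2.0 is a non-Variable arg
out.sum().backward()                # x.grad is all ones, y.grad is all twos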

For comparison, here's how a legacy definition of MultiplyAdd would look:

class MultiplyAdd(Function):
    def __init__(self, scalar):
        super().__init__()
        self.scalar = scalar

    def forward(self, input1, input2):
        return input1 + self.scalar * input2

    def backward(self, grad_output):
        return grad_output, self.scalar * grad_output
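For reference, a legacy Function is first instantiated and then called, e.g. out = MultiplyAdd(2.0)(x, y), whereas the new-style class is applied directly as sketched above.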

@once_differentiable

The fact that backward now takes Variables might unnecessarily complicate implementations of custom functions that e.g. call into other libraries. For that reason, this PR also introduces a @once_differentiable decorator that can be used to wrap backward. With it, backward functions will receive a tensor grad_output and will be expected to return a gradient tensor for each tensor argument given to forward (and None for all other arguments).

class SciPyFunction(Function):
    @staticmethod
    def forward(ctx, input1, input2):
        # scipy.my_function is a placeholder; convert tensors to numpy and back
        result = scipy.my_function(input1.numpy(), input2.numpy())
        return torch.from_numpy(result)

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        # thanks to @once_differentiable, grad_output is a plain tensor here
        grad_numpy = grad_output.numpy()
        return (torch.from_numpy(scipy.my_function2(grad_numpy)),
                torch.from_numpy(scipy.my_function3(grad_numpy)))

torch.autograd.backward

Added create_graph. If True, the graph of the vjp will be constructed, allowing you to differentiate the grad computation. Defaults to True if grad_variables contains at least one non-volatile Variable, and False otherwise.
Renamed retain_variables to retain_graph. The old argument will remain supported until v0.3, but will print deprecation warnings. If unspecified, retain_graph defaults to the value of create_graph.

If grad_variables contains tensors, they are automatically promoted to Variables (volatile unless create_graph is True). Also, None entries in grad_variables are now accepted if their corresponding variables entries are scalar Variables (a grad_output filled with 1 is allocated for them). Additionally, if all grad_variables would be None, the argument can be omitted entirely.
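A short sketch of these defaults (variable names are illustrative):

import torch
from torch.autograd import Variable

x = Variable(torch.randn(4), requires_grad=True)
loss = x.pow(2).sum()  # scalar output

# grad_variables can be omitted for scalar outputs;
# a grad_output filled with 1 is allocated automatically
torch.autograd.backward(loss, create_graph=True)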

torch.autograd.grad

While the Chainer-style API is great for first-order grads, it doesn't work nearly as well when computing higher-order derivatives. See this example:

x = Variable(torch.randn(2, 2), requires_grad=True)
x.mul(2).sum().backward(create_graph=True)
y = x.grad.mul(2).sum()
# This accumulates grad of grad into x.grad, adding together results of both backward() calls
y.backward() 

For that reason, this PR also implements grad: a functional-style API that computes the vjp and, instead of accumulating it into the .grad of all leaves, returns a list of grads w.r.t. the given function inputs (parameters are considered inputs too).

Example:

from torch.autograd import grad

x = Variable(torch.randn(2, 2), requires_grad=True)
x.mul(2).sum().backward(create_graph=True)
y = x.grad.mul(2).sum()
# The line below **doesn't change x.grad**
x_hv = grad(y, x)  # grad of y w.r.t. x; a grad_outputs argument would normally be needed, but y is a scalar

The outputs, inputs, and grad_outputs arguments can be either sequences of Variables (or of Tensors and Nones in the case of grad_outputs) or single Variables.

If one doesn't request the grad w.r.t. all leaf Variables, unneeded gradients are not computed and won't be accumulated into them (by default grad has no side effects). If the only_inputs argument is set to False, the whole graph will be differentiated: grads w.r.t. inputs will be returned in a list rather than accumulated into .grad, while grads w.r.t. all other leaves will be accumulated into their .grad.
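For non-scalar outputs, grad_outputs supplies the vector of the vjp. A minimal sketch (names are illustrative):

import torch
from torch.autograd import Variable, grad

x = Variable(torch.randn(3), requires_grad=True)
y = x * 2  # non-scalar output

# grad_outputs plays the role of grad_output in backward()
gx, = grad(y, x, grad_outputs=torch.ones(3))
print(gx)  # all twos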

.grad semantics

By default the semantics are the same as right now. When not using any of the options implemented in this PR, .grad Variables will be volatile, and incoming grads will be accumulated in-place (both the Variable and its .data stay the same objects; while we don't guarantee that, some people depend on it in their scripts, so it's best to keep supporting it unless there's no other way).
However, when using derivative graphs, these Variables need to have their .grad_fn set correctly and shouldn't be modified in-place (they might have been used in some functions!). For that reason, in such cases the .grad attribute will point to a new Variable, with new .data, after each accumulation.

To sum up:

.grad        | New grad     | Action
-------------|--------------|-----------------------------------------------------------
volatile     | volatile     | Accumulated in-place into .grad
volatile     | non-volatile | Accumulated in-place into .grad, which remains volatile
non-volatile | volatile     | New grad is converted to a Variable that doesn't require grad and added out-of-place; the result overwrites .grad
non-volatile | non-volatile | Added out-of-place; the result overwrites .grad
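A small, hypothetical check of the out-of-place case described above (it assumes grads created under create_graph=True are non-volatile, as this PR describes):

import torch
from torch.autograd import Variable

x = Variable(torch.randn(2), requires_grad=True)
x.sum().backward(create_graph=True)
g1 = x.grad
x.sum().backward(create_graph=True)
# the second accumulation happened out-of-place: .grad points to a new Variable
assert x.grad is not g1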

Implementation details

  • Variables no longer subclass Function and therefore can no longer appear in the graph. After this PR graphs contain AccumulateGrad nodes instead of Variables.
  • Engine now supports per-function callbacks that are called before evaluating a function. If a callback returns false, the function's apply won't be called (all its output gradients will default to null, which is equivalent to 0), and its next_functions won't be added to the ready queue (unless they are already waiting for execution and this was their last dependency).


@apaszke apaszke changed the title Refactor attribute names in autograd [WIP] Refactor attribute names in autograd Mar 16, 2017
@apaszke apaszke closed this Mar 16, 2017
@apaszke apaszke reopened this Mar 16, 2017
@apaszke apaszke force-pushed the autograd branch 4 times, most recently from 92b1e02 to 940fe73 on March 17, 2017 at 20:46
@apaszke apaszke force-pushed the autograd branch 12 times, most recently from d5ef9b9 to 10b7285 on March 30, 2017 at 22:51
@apaszke apaszke changed the title [WIP] Refactor attribute names in autograd Refactor attribute names in autograd Mar 30, 2017
@apaszke apaszke changed the title Refactor attribute names in autograd Autograd refactor pt 1 Mar 30, 2017
@apaszke apaszke changed the title Autograd refactor pt 1 Autograd refactor Apr 1, 2017

@soumith
Member

soumith commented Apr 6, 2017

For some context on the refactor, see https://gist.github.com/apaszke/a8bc5f167ca4c0f3a830a23e296d3daf

@apaszke
Contributor Author

apaszke commented May 2, 2017

It's been renamed to grad_fn. Are you sure you're using the version from master? If yes can you please send me a small script I could use to reproduce and fix the error?

@rishab96

rishab96 commented May 3, 2017

@apaszke if I use the version from http://pytorch.org/ do I get the same version as master? For all 3 - previous_functions, next_functions and grad_fn I am getting 'ConvNdBackward' object has no attribute 'x'

@soumith
Member

soumith commented May 3, 2017

@rishab96 the version from http://pytorch.org is not the same version as master. It's from the v0.1.12 branch.

@soumith
Member

soumith commented May 3, 2017

To get the master branch, you have to compile from source according to the instructions in README.md.

@ChenRocks

@apaszke @camigord I got the same error with LSTM Cell and I have filed issue #1450

@camigord

camigord commented May 3, 2017

Hi @apaszke, I was using the adagrad branch; the problem seems to be solved in master. Thanks for the replies.

@HiiYL

HiiYL commented May 8, 2017

I'm trying to implement WGAN-GP using the create_graph argument of torch.autograd.backward; however, it throws the following error when backward() is called on the gradients:
RuntimeError: ConvBackward is not differentiable

My code is here:
https://github.com/HiiYL/WGAN-GP-PyTorch/blob/master/main.py#L224

@mjdietzx
Contributor

mjdietzx commented May 8, 2017

I'm also trying to implement WGAN with gradient penalty. I've taken a slightly different approach than @HiiYL.

# calculate `x_hat` where `x` is a batch of real images and `x_z` is a batch of generated images.
epsilon = torch.randn(batch_size, 1, 1, 1)
x_hat = epsilon.expand_as(x.data) * x.data + (1.0 - epsilon.expand_as(x_z.data)) * x_z.data
x_hat = autograd.Variable(x_hat).detach()

o = discriminator.forward(x_hat)
gradients = autograd.grad(o, x_hat)
gradient_penalty = autograd.Variable(10.0 * torch.pow(torch.norm(gradients, p=2) - 1.0, 2.0))
gradient_penalty.backward()

and gradients = autograd.grad(o, x_hat) throws RuntimeError: grad can be implicitly created only for scalar outputs

@lopezpaz

lopezpaz commented May 8, 2017

@apaszke helped me code some minimal examples to compute second-order derivatives.

Perhaps helpful to @HiiYL @mjdietzx:

from torch.autograd import Variable, grad
import torch

x = Variable(torch.ones(1), requires_grad=True)
y = x.pow(3)

g = grad(y, x, create_graph=True)
print(g) # g = 3

g2 = grad(g, x)
print(g2) # g2 = 6

To implement the gradient penalty in WGAN:

import torch
from torch.autograd import Variable, grad

torch.manual_seed(0)

net = torch.nn.Linear(10,1)
mse = torch.nn.MSELoss()

x = Variable(torch.randn(128, 10), requires_grad=True)
y = Variable(torch.randn(128, 1))

# your normal loss computation goes here
net.zero_grad()
output = net(x)
loss = mse(output, y)
torch.autograd.backward(loss, create_graph=True)
update1 = net.weight.grad.data.clone()

# gradient penalization (effectively, second order derivative)
gradient_penalty = (grad(output.mean(), x, create_graph=True)[0].norm() - 1).pow(2)
gradient_penalty.backward() # this will be added to the grads w.r.t. the loss
update2 = net.weight.grad.data.clone()

print(update1)
print(update2)
print((update1-update2).norm())

@apaszke
Contributor Author

apaszke commented May 8, 2017

@HiiYL the problem is that you're trying to compute the grad of grad of Convolution, but it's not implemented yet. This PR only added the machinery necessary to compute higher order derivatives, but didn't actually add support to the existing autograd functions.
@mjdietzx First, you should never call .forward directly on the model. Just do discriminator(x_hat). The error should tell you exactly what's wrong - output of discriminator is not a scalar, so you need to specify a third argument to autograd.grad. You can find the description in the docs.

@mjdietzx
Contributor

mjdietzx commented May 8, 2017

After fixing my first problem based on your suggestions, I now also get the same error RuntimeError: ConvBackward is not differentiable when calling backward() on the gradients. So I guess gradient penalty in WGAN won't be possible until support for higher-order derivatives is added to the existing autograd functions? Is that on the roadmap @apaszke?

@apaszke
Contributor Author

apaszke commented May 8, 2017

It is, but we need a while to adapt the codebase (and the conference season doesn't help with that).

@apaszke
Contributor Author

apaszke commented May 8, 2017

Also, we're accepting PRs that adapt existing functions for higher order grads (I think there are 2 like that open at the moment, and a few already merged).

@apaszke apaszke deleted the autograd branch May 8, 2017 19:58
@caogang
Contributor

caogang commented May 11, 2017

Hi @apaszke, is there any branch or contributor working on adding support for Variable backward in auto-generated THNN functions now?

@apaszke
Contributor Author

apaszke commented May 11, 2017

@gchanan is going to work on it after he wraps up broadcasting, so at the moment there's no one working on it.

@liboyue
Contributor

liboyue commented Jul 1, 2017

@apaszke The new code can only take a second-order derivative of a scalar. Could you please tell me if there is any way to calculate a Hessian matrix now?

@soumith
Member

soumith commented Jul 1, 2017

@liboyue you can only create a full Hessian matrix with a for-loop, computing it one scalar at a time.
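A hedged sketch of that loop, using the grad API from this PR (the function and size are illustrative):

import torch
from torch.autograd import Variable, grad

n = 3
x = Variable(torch.randn(n), requires_grad=True)
y = x.pow(2).sum() * x.sum()  # any twice-differentiable scalar function of x

dx, = grad(y, x, create_graph=True)

rows = []
for i in range(n):
    e_i = torch.zeros(n)
    e_i[i] = 1.0
    # each pass computes one Hessian-vector product H @ e_i, i.e. one row
    row, = grad(dx, x, grad_outputs=e_i, retain_graph=True)
    rows.append(row.data)
hessian = torch.stack(rows)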

@liboyue
Contributor

liboyue commented Jul 1, 2017

Thanks @soumith. I tried to do so, but it is too slow. Do you have any plan to implement this function in the future?

@apaszke
Contributor Author

apaszke commented Jul 1, 2017

Not really. The only thing we could do would be to batch the ops a bit more across scalars, but it won't give you a large speedup. Computing the full Hessian is very expensive, which is why all these tricks for estimating it exist.

@liboyue
Contributor

liboyue commented Jul 2, 2017

@apaszke You're right. I am just trying to use the Hessian. I tried several autodiff packages, and PyTorch is the fastest by far, but it is still too slow. Thanks :)

@lucasb-eyer
Contributor

Hi all, the following code can be used to compute the full Hessian in a loop, as mentioned by @soumith. I'm wondering if anyone knows of a trick to compute the diagonal of the Hessian in a single pass? I wasn't able to come up with anything myself.

Code:

x = Variable(torch.FloatTensor([1,2]), requires_grad=True)
y = x[0].pow(2) * x[1]

dx, = grad(y, x, create_graph=True)
print(dx)  # (4,1)'

dx_dx1, = grad(dx, x, grad_outputs=torch.FloatTensor([1,0]), retain_graph=True)
dx_dx2, = grad(dx, x, grad_outputs=torch.FloatTensor([0,1]))

print(dx_dx1)  # (4,2)'
print(dx_dx2)  # (2,0)'

@apaszke
Contributor Author

apaszke commented Jan 4, 2018

There's no way to do that in one go. You can only compute Hessian-vector products this way, and there's no vector that will give you the diagonal when you multiply it by the matrix.

@lucasb-eyer
Contributor

Thanks for the explanation, makes sense!

jjsjann123 pushed a commit to jjsjann123/pytorch that referenced this pull request Aug 5, 2021
