
SummaryWriter.add_graph doesn't support non-tensor model inputs #520

Closed

@dalek-who opened this issue

Description

🐛 Bug

To Reproduce

Steps to reproduce the behavior:

1. Run my script below:

import torch
import torch.nn as nn
import torch.nn.functional as F
# from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter

# bug 1: bool type inputs
class Net_1(nn.Module):
    def __init__(self, dropout=0.5):
        super(Net_1, self).__init__()
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 10)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, use_dropout=False):
        x = F.relu(self.fc1(x))
        if use_dropout:
            x = self.dropout(x)  # or other operations ....
        x = F.relu(self.fc2(x))
        return x

with SummaryWriter("bugs") as w:
    net = Net_1()
    input_x = torch.randn((2,120))
    w.add_graph(net, (input_x, True))


# bug 2: None type inputs (might be argument's default value)
class Net_2(nn.Module):
    def __init__(self):
        super(Net_2, self).__init__()
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(120, 84)
        self.fc4 = nn.Linear(84, 10)

    def forward(self, x, y=None, z=None):
        x = F.relu(self.fc1(x))
        if y is not None:
            y = F.relu(self.fc2(y))
            x = x + y
        if z is not None:
            z = F.relu(self.fc3(z))
            x = x + z
        x = F.relu(self.fc4(x))
        return x

with SummaryWriter("bugs") as w:
    net = Net_2()
    input_x = torch.randn((2,120))
    input_y = None
    input_z = torch.randn((2,120))
    w.add_graph(net, (input_x, input_y, input_z))


# bug 3: list type inputs (dict, or other Python built-in types like int, str, ... may also hit this problem)
class Net_3(nn.Module):
    def __init__(self):
        super(Net_3, self).__init__()
        self.fc_list = [nn.Linear(120, 120) for _ in range(10)]
        self.fc_n = nn.Linear(120, 10)

    def forward(self, x, index:list=None):
        if index is not None:
            for i in index:
                x = F.relu(self.fc_list[i](x))
        x = F.relu(self.fc_n(x))
        return x

with SummaryWriter("bugs") as w:
    net = Net_3()
    input_x = torch.randn((2, 120))
    index = [1, 5, 1, 7, 0]
    w.add_graph(net, (input_x, index))

and you can see the traceback (taking bug 1 as an example):

Error occurs, No graph saved
Traceback (most recent call last):
  File "<input>", line 1, in <module>
  File "/Applications/PyCharm.app/Contents/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/Applications/PyCharm.app/Contents/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/wangyuanzheng/Downloads/xxxxxxx/project/albert_pytorch/dev/add_graph_bug.py", line 25, in <module>
    w.add_graph(net, (input_x, True))
  File "/Users/wangyuanzheng/anaconda3/envs/CCFBigData-torch/lib/python3.7/site-packages/torch/utils/tensorboard/writer.py", line 682, in add_graph
    self._get_file_writer().add_graph(graph(model, input_to_model, verbose))
  File "/Users/wangyuanzheng/anaconda3/envs/CCFBigData-torch/lib/python3.7/site-packages/torch/utils/tensorboard/_pytorch_graph.py", line 239, in graph
    raise e
  File "/Users/wangyuanzheng/anaconda3/envs/CCFBigData-torch/lib/python3.7/site-packages/torch/utils/tensorboard/_pytorch_graph.py", line 234, in graph
    trace = torch.jit.trace(model, args)
  File "/Users/wangyuanzheng/anaconda3/envs/CCFBigData-torch/lib/python3.7/site-packages/torch/jit/__init__.py", line 858, in trace
    check_tolerance, _force_outplace, _module_class)
  File "/Users/wangyuanzheng/anaconda3/envs/CCFBigData-torch/lib/python3.7/site-packages/torch/jit/__init__.py", line 997, in trace_module
    module._c._create_method_from_trace(method_name, func, example_inputs, var_lookup_fn, _force_outplace)
RuntimeError: Type 'Tuple[Tensor, bool]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced (toTraceableIValue at ../torch/csrc/jit/pybind_utils.h:298)
frame #0: c10::Error::Error(c10::SourceLocation, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&) + 135 (0x110c479e7 in libc10.dylib)
frame #1: torch::jit::toTraceableIValue(pybind11::handle) + 1280 (0x110246740 in libtorch_python.dylib)
frame #2: torch::jit::toTypedStack(pybind11::tuple const&) + 31 (0x1102e7edf in libtorch_python.dylib)
frame #3: void pybind11::cpp_function::initialize<torch::jit::script::initJitScriptBindings(_object*)::$_16, void, torch::jit::script::Module&, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&, pybind11::function, pybind11::tuple, pybind11::function, bool, pybind11::name, pybind11::is_method, pybind11::sibling>(torch::jit::script::initJitScriptBindings(_object*)::$_16&&, void (*)(torch::jit::script::Module&, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&, pybind11::function, pybind11::tuple, pybind11::function, bool), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::'lambda'(pybind11::detail::function_call&)::__invoke(pybind11::detail::function_call&) + 147 (0x11031e4e3 in libtorch_python.dylib)
frame #4: pybind11::cpp_function::dispatcher(_object*, _object*, _object*) + 3372 (0x10fe57d3c in libtorch_python.dylib)
<omitting python frames>

Expected behavior

writer.add_graph should run normally.

Environment

Please copy and paste the output from our
environment collection script
(or fill out the checklist below manually).

You can get the script and run it with:

wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py

Collecting environment information...
PyTorch version: 1.3.0
Is debug build: No
CUDA used to build PyTorch: None
OS: Mac OSX 10.14.6
GCC version: Could not collect
CMake version: Could not collect
Python version: 3.7
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Versions of relevant libraries:
[pip] numpy==1.17.2
[pip] torch==1.3.0
[pip] torchvision==0.4.1
[conda] torch 1.3.0 pypi_0 pypi
[conda] torchvision 0.4.1 pypi_0 pypi

Additional context

1. tensorboardX.SummaryWriter.add_graph has the same bug as torch.utils.tensorboard.SummaryWriter.add_graph.
2. Besides this bug, I hope add_graph could accept not only a tuple of positional arguments but also a dict of keyword arguments for model.forward()'s inputs (a rough adapter sketch follows below).
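
To illustrate item 2, here is a rough adapter sketch (KwargsWrapper is a hypothetical helper, not an existing tensorboardX or torch.utils.tensorboard feature): it maps named inputs onto a fixed positional order so the current tuple-only API can still trace the model.

import torch
import torch.nn as nn

# Hypothetical helper, not part of tensorboardX: adapt a model whose forward
# takes keyword arguments so that add_graph can call it with a plain tuple
# of tensors.
class KwargsWrapper(nn.Module):
    def __init__(self, model, arg_names):
        super().__init__()
        self.model = model
        self.arg_names = arg_names  # fixed order in which tensors are received

    def forward(self, *tensors):
        kwargs = dict(zip(self.arg_names, tensors))
        return self.model(**kwargs)

# usage sketch with Net_2 from above: prepare inputs as a dict, then flatten
# them in a fixed order (y is left at its default None here)
# inputs = {"x": torch.randn(2, 120), "z": torch.randn(2, 120)}
# wrapped = KwargsWrapper(Net_2(), list(inputs.keys()))
# w.add_graph(wrapped, tuple(inputs.values()))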

Activity

lanpa (Owner) commented on Oct 23, 2019

"Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced"
I think this behavior is by design. (To make the code compact and maintainable)

Here's are some workarounds:
Net1: w.add_graph(net, (input_x, torch.BoolTensor([True])))
Net2: replace input_y with boolean False, and apply the above.
Net3: wrap index with the torch tensor.

I don't suggest visualizing a graph with a dynamic structure. Especially those with many if/else in the forward. This would result in cluttered visualization (Like the graph you would see in tensorflow's output.)
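
Concretely, the Net_1 and Net_3 workarounds could look roughly like this (a minimal sketch assuming the class definitions from the report above; expect TracerWarnings because the Python-side control flow gets frozen into the trace):

import torch
from tensorboardX import SummaryWriter
# Net_1 and Net_3 are the classes defined in the report above.

with SummaryWriter("bugs_workaround") as w:
    # Net_1: pass the flag as a one-element bool tensor instead of a Python bool.
    # `if use_dropout:` still sees a truthy value, but the chosen branch is
    # baked into the traced graph.
    w.add_graph(Net_1(), (torch.randn(2, 120), torch.BoolTensor([True])))

    # Net_3: pass the index list as an integer tensor; each 0-dim element can
    # still index the Python list of layers, again frozen at trace time.
    w.add_graph(Net_3(), (torch.randn(2, 120), torch.tensor([1, 5, 1, 7, 0])))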

ecolss commented on Dec 12, 2020

"Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced"
I think this behavior is by design. (To make the code compact and maintainable)

Here's are some workarounds:
Net1: w.add_graph(net, (input_x, torch.BoolTensor([True])))
Net2: replace input_y with boolean False, and apply the above.
Net3: wrap index with the torch tensor.

I don't suggest visualizing a graph with a dynamic structure. Especially those with many if/else in the forward. This would result in cluttered visualization (Like the graph you would see in tensorflow's output.)

Regarding your workarounds, what if one param should be passed as None?
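
One hedged way around the None case (a sketch, not something suggested in this thread): pin the None argument inside a thin wrapper module so that add_graph only ever receives tensors; only the tensor path gets visualized, of course.

import torch
import torch.nn as nn

# Hypothetical wrapper (name and approach are an assumption, not a tensorboardX
# API): fix y=None inside forward so the traced inputs contain only tensors.
class FixedNoneWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, z):
        return self.model(x, y=None, z=z)

# usage sketch, reusing Net_2 from the report:
# wrapped = FixedNoneWrapper(Net_2())
# w.add_graph(wrapped, (torch.randn(2, 120), torch.randn(2, 120)))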

lanpa (Owner) commented on Mar 19, 2021

Closing this since the graph visualization has been delegated to the PyTorch main repo (d7238f5). Please file an issue there if special network inputs need to be supported in the visualization.
