Description
🐛 Bug
To Reproduce
Steps to reproduce the behavior:
1. Run my script below:
import torch
import torch.nn as nn
import torch.nn.functional as F
# from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter


# bug 1: bool type inputs
class Net_1(nn.Module):
    def __init__(self, dropout=0.5):
        super(Net_1, self).__init__()
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 10)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, use_dropout=False):
        x = F.relu(self.fc1(x))
        if use_dropout:
            x = self.dropout(x)  # or other operations ....
        x = F.relu(self.fc2(x))
        return x


with SummaryWriter("bugs") as w:
    net = Net_1()
    input_x = torch.randn((2, 120))
    w.add_graph(net, (input_x, True))


# bug 2: None type inputs (e.g. an argument's default value)
class Net_2(nn.Module):
    def __init__(self):
        super(Net_2, self).__init__()
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(120, 84)
        self.fc4 = nn.Linear(84, 10)

    def forward(self, x, y=None, z=None):
        x = F.relu(self.fc1(x))
        if y is not None:
            y = F.relu(self.fc2(y))
            x = x + y
        if z is not None:
            z = F.relu(self.fc3(z))
            x = x + z
        x = F.relu(self.fc4(x))
        return x


with SummaryWriter("bugs") as w:
    net = Net_2()
    input_x = torch.randn((2, 120))
    input_y = None
    input_z = torch.randn((2, 120))
    w.add_graph(net, (input_x, input_y, input_z))


# bug 3: list type inputs (dicts or other Python built-in types such as int and str may hit the same problem)
class Net_3(nn.Module):
    def __init__(self):
        super(Net_3, self).__init__()
        # nn.ModuleList (not a plain list) registers the layers as submodules
        self.fc_list = nn.ModuleList([nn.Linear(120, 120) for _ in range(10)])
        self.fc_n = nn.Linear(120, 10)

    def forward(self, x, index: list = None):
        if index is not None:
            for i in index:
                x = F.relu(self.fc_list[i](x))
        x = F.relu(self.fc_n(x))
        return x


with SummaryWriter("bugs") as w:
    net = Net_3()
    input_x = torch.randn((2, 120))
    index = [1, 5, 1, 7, 0]
    w.add_graph(net, (input_x, index))
And you can see the traceback (taking bug 1, the bool input, as an example):
Error occurs, No graph saved
Traceback (most recent call last):
  File "<input>", line 1, in <module>
  File "/Applications/PyCharm.app/Contents/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars) # execute the script
  File "/Applications/PyCharm.app/Contents/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/wangyuanzheng/Downloads/xxxxxxx/project/albert_pytorch/dev/add_graph_bug.py", line 25, in <module>
    w.add_graph(net, (input_x, True))
  File "/Users/wangyuanzheng/anaconda3/envs/CCFBigData-torch/lib/python3.7/site-packages/torch/utils/tensorboard/writer.py", line 682, in add_graph
    self._get_file_writer().add_graph(graph(model, input_to_model, verbose))
  File "/Users/wangyuanzheng/anaconda3/envs/CCFBigData-torch/lib/python3.7/site-packages/torch/utils/tensorboard/_pytorch_graph.py", line 239, in graph
    raise e
  File "/Users/wangyuanzheng/anaconda3/envs/CCFBigData-torch/lib/python3.7/site-packages/torch/utils/tensorboard/_pytorch_graph.py", line 234, in graph
    trace = torch.jit.trace(model, args)
  File "/Users/wangyuanzheng/anaconda3/envs/CCFBigData-torch/lib/python3.7/site-packages/torch/jit/__init__.py", line 858, in trace
    check_tolerance, _force_outplace, _module_class)
  File "/Users/wangyuanzheng/anaconda3/envs/CCFBigData-torch/lib/python3.7/site-packages/torch/jit/__init__.py", line 997, in trace_module
    module._c._create_method_from_trace(method_name, func, example_inputs, var_lookup_fn, _force_outplace)
RuntimeError: Type 'Tuple[Tensor, bool]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced (toTraceableIValue at ../torch/csrc/jit/pybind_utils.h:298)
frame #0: c10::Error::Error(c10::SourceLocation, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&) + 135 (0x110c479e7 in libc10.dylib)
frame #1: torch::jit::toTraceableIValue(pybind11::handle) + 1280 (0x110246740 in libtorch_python.dylib)
frame #2: torch::jit::toTypedStack(pybind11::tuple const&) + 31 (0x1102e7edf in libtorch_python.dylib)
frame #3: void pybind11::cpp_function::initialize<torch::jit::script::initJitScriptBindings(_object*)::$_16, void, torch::jit::script::Module&, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&, pybind11::function, pybind11::tuple, pybind11::function, bool, pybind11::name, pybind11::is_method, pybind11::sibling>(torch::jit::script::initJitScriptBindings(_object*)::$_16&&, void (*)(torch::jit::script::Module&, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&, pybind11::function, pybind11::tuple, pybind11::function, bool), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::'lambda'(pybind11::detail::function_call&)::__invoke(pybind11::detail::function_call&) + 147 (0x11031e4e3 in libtorch_python.dylib)
frame #4: pybind11::cpp_function::dispatcher(_object*, _object*, _object*) + 3372 (0x10fe57d3c in libtorch_python.dylib)
<omitting python frames>
Expected behavior
writer.add_graph should trace the model and save its graph without raising an error.
Environment
Please copy and paste the output from our environment collection script (or fill out the checklist below manually).
You can get the script and run it with:
wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py
Collecting environment information...
PyTorch version: 1.3.0
Is debug build: No
CUDA used to build PyTorch: None
OS: Mac OSX 10.14.6
GCC version: Could not collect
CMake version: Could not collect
Python version: 3.7
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Versions of relevant libraries:
[pip] numpy==1.17.2
[pip] torch==1.3.0
[pip] torchvision==0.4.1
[conda] torch 1.3.0 pypi_0 pypi
[conda] torchvision 0.4.1 pypi_0 pypi
Additional context
1. tensorboardX.SummaryWriter.add_graph has the same bug as torch.utils.tensorboard.SummaryWriter.add_graph.
2. Besides this bug, I hope add_graph could accept not only a tuple of positional arguments but also a dict of keyword arguments for model.forward().
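Until add_graph accepts keyword arguments, point 2 can be approximated with a wrapper module. A minimal sketch, assuming a hypothetical `KwargsAdapter` class (not a torch or tensorboardX API): it takes one tuple of tensors positionally, which the tracer accepts, and re-pairs it with keyword names before calling the model. Leaving `y` out of the dict keeps its `None` default, which also sidesteps bug 2. It is checked with `torch.jit.trace`, the call add_graph makes internally per the traceback; passing `(example,)` to add_graph is expected to behave the same but is not verified here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class KwargsAdapter(nn.Module):
    """Hypothetical wrapper (not part of torch or tensorboardX).

    Tracing passes example inputs positionally, so this adapter accepts one
    tuple of tensors and re-pairs it with keyword names before calling the
    wrapped model.
    """

    def __init__(self, model: nn.Module, names):
        super().__init__()
        self.model = model
        self.names = tuple(names)  # keyword names, in the same order as the tensors

    def forward(self, tensors):
        # A tuple of tensors is an accepted trace input; rebuild the kwargs here.
        return self.model(**dict(zip(self.names, tensors)))


class Net_2(nn.Module):
    """Same layers and branches as bug 2 in the script above."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(120, 84)
        self.fc4 = nn.Linear(84, 10)

    def forward(self, x, y=None, z=None):
        x = F.relu(self.fc1(x))
        if y is not None:
            x = x + F.relu(self.fc2(y))
        if z is not None:
            x = x + F.relu(self.fc3(z))
        return F.relu(self.fc4(x))


net = Net_2()
# y is left out, so Net_2.forward keeps its default y=None.
kwargs = {"x": torch.randn(2, 120), "z": torch.randn(2, 120)}
adapter = KwargsAdapter(net, kwargs.keys())
example = tuple(kwargs.values())
traced = torch.jit.trace(adapter, (example,))
print(torch.allclose(traced(example), net(**kwargs)))  # True
```

The dict's insertion order fixes the positional order, so `kwargs.keys()` and `kwargs.values()` stay aligned.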
lanpa commented on Oct 23, 2019
"Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced"
I think this behavior is by design (to keep the code compact and maintainable).
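The rule quoted from the error message can be written out as a small predicate to check inputs before calling add_graph. This is a hypothetical helper for illustration, not the check torch runs (that happens in C++, in `toTraceableIValue`); it only inspects container values and does not validate dict key types.

```python
import torch


def is_traceable(obj) -> bool:
    """Hypothetical helper: True if obj is a Tensor or a (possibly nested)
    list, tuple, or dict of Tensors, following the quoted error message."""
    if isinstance(obj, torch.Tensor):
        return True
    if isinstance(obj, (list, tuple)):
        return all(is_traceable(item) for item in obj)
    if isinstance(obj, dict):
        # Only the values are checked; key types are not validated here.
        return all(is_traceable(value) for value in obj.values())
    # bool, None, int, str, and other Python objects are rejected.
    return False


x = torch.randn(2, 120)
print(is_traceable((x, True)))                             # bug 1: False
print(is_traceable((x, None, x)))                          # bug 2: False
print(is_traceable((x, [1, 5, 1, 7, 0])))                  # bug 3: False
print(is_traceable((x, torch.tensor([1, 5, 1, 7, 0]))))    # index as a tensor: True
```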
Here are some workarounds:
Net1: w.add_graph(net, (input_x, torch.BoolTensor([True])))
Net2: replace input_y with a boolean False and apply the above.
Net3: wrap index in a torch tensor.
I don't suggest visualizing a graph with a dynamic structure, especially one with many if/else branches in forward. This would result in a cluttered visualization (like the graph you would see in TensorFlow's output).
ecolss commented on Dec 12, 2020
Regarding your workarounds, what if one param should be passed as None?
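One way to handle a parameter that has to stay None (or bug 1's bool flag, or bug 3's index list) is to bind it outside the traced signature, so only tensors reach the tracer. A minimal sketch, assuming a hypothetical `ConstArgs` wrapper that is not part of torch or tensorboardX; it is verified with `torch.jit.trace` directly, and passing the wrapper to add_graph with `(x,)` is expected, not verified, to behave the same.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConstArgs(nn.Module):
    """Hypothetical wrapper (not a torch API): binds non-tensor arguments
    such as None, a bool flag, or a Python list when it is constructed,
    so only the tensor input x is handed to the tracer."""

    def __init__(self, model: nn.Module, **constants):
        super().__init__()
        self.model = model
        self.constants = dict(constants)

    def forward(self, x):
        return self.model(x, **self.constants)


class Net_1(nn.Module):
    """Same layers and bool flag as bug 1 in the script above."""

    def __init__(self, dropout=0.5):
        super().__init__()
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 10)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, use_dropout=False):
        x = F.relu(self.fc1(x))
        if use_dropout:
            x = self.dropout(x)
        return F.relu(self.fc2(x))


net = Net_1().eval()  # eval() makes dropout deterministic so outputs can be compared
wrapped = ConstArgs(net, use_dropout=True)  # a None argument would be bound the same way
x = torch.randn(2, 120)
traced = torch.jit.trace(wrapped, (x,))
print(torch.allclose(traced(x), net(x, use_dropout=True)))  # True
```

Because the bound values never reach the tracer, the traced graph contains only the branch they select, so each distinct None/True/list value needs its own trace.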
lanpa commented on Mar 19, 2021
Closing this since the graph visualization has been delegated to the PyTorch main repo. d7238f5 Please file an issue there if a special network input is needed when displaying the visualization.