unable to load model #15

Closed
lazysquid opened this issue Nov 9, 2017 · 1 comment
Comments

@lazysquid

Hello. I trained a model on my dataset with your SegNet implementation, using the DePool2D custom layer.
Training went smoothly and the results look good.

However, when I try to reload the model I get the error below, and I can't figure out what the problem is. Do you have any tips?

Traceback (most recent call last):
  File "infer_test.py", line 188, in <module>
    main()
  File "infer_test.py", line 183, in main
    join(model_root, "infer_results", splitext(snapshot)[0]))
  File "infer_test.py", line 91, in infer_data
    model = load_model(snapshot_path, custom_objects = {'DePool2D' : DePool2D(MaxPooling2D)})
  File "/usr/lib/python3.6/site-packages/keras/models.py", line 239, in load_model
    model = model_from_config(model_config, custom_objects=custom_objects)
  File "/usr/lib/python3.6/site-packages/keras/models.py", line 313, in model_from_config
    return layer_module.deserialize(config, custom_objects=custom_objects)
  File "/usr/lib/python3.6/site-packages/keras/layers/__init__.py", line 54, in deserialize
    printable_module_name='layer')
  File "/usr/lib/python3.6/site-packages/keras/utils/generic_utils.py", line 139, in deserialize_keras_object
    list(custom_objects.items())))
  File "/usr/lib/python3.6/site-packages/keras/engine/topology.py", line 2487, in from_config
    process_layer(layer_data)
  File "/usr/lib/python3.6/site-packages/keras/engine/topology.py", line 2473, in process_layer
    custom_objects=custom_objects)
  File "/usr/lib/python3.6/site-packages/keras/layers/__init__.py", line 54, in deserialize
    printable_module_name='layer')
  File "/usr/lib/python3.6/site-packages/keras/utils/generic_utils.py", line 141, in deserialize_keras_object
    return cls.from_config(config['config'])
  File "/usr/lib/python3.6/site-packages/keras/engine/topology.py", line 1252, in from_config
    return cls(**config)
TypeError: __init__() missing 1 required positional argument: 'pool2d_layer'

I load the model with this code:

from build_segnet import DePool2D
...
model = load_model(snapshot_path, custom_objects = {'DePool2D' : DePool2D()})
@lazysquid
Author

I guess this problem is similar to keras-team/keras#5815.

I worked around it with two different approaches.

Approach 1

model = create_segnet((256, 256, 3), 3, indices=True, ker_init="he_normal")
model.load_weights(snapshot_path)

Approach 2
Change the __init__ of DePool2D so that pool2d_layer has a default value:

    def __init__(self, pool2d_layer=MaxPooling2D, *args, **kwargs):
        self._pool2d_layer = pool2d_layer
        super().__init__(*args, **kwargs)
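
As a rough illustration of why the default argument makes the error go away, here is a minimal sketch that does not use Keras at all. `FakeLayer`, `FakePool`, `DePoolRequired` and `DePoolDefault` are hypothetical stand-ins. The only behaviour copied from the traceback is that `from_config` ends in `return cls(**config)`, and the sketch assumes the saved config has no `pool2d_layer` entry.

```python
# Stand-in for the Keras base layer. Only the from_config behaviour seen
# in the traceback (`return cls(**config)`) is reproduced here.
class FakeLayer:
    def __init__(self, name=None):
        self.name = name

    @classmethod
    def from_config(cls, config):
        return cls(**config)


class FakePool:
    """Hypothetical placeholder for MaxPooling2D."""


class DePoolRequired(FakeLayer):
    # pool2d_layer is a required argument, as in the failing version.
    def __init__(self, pool2d_layer, *args, **kwargs):
        self._pool2d_layer = pool2d_layer
        super().__init__(*args, **kwargs)


class DePoolDefault(FakeLayer):
    # pool2d_layer has a default value, as in Approach 2.
    def __init__(self, pool2d_layer=FakePool, *args, **kwargs):
        self._pool2d_layer = pool2d_layer
        super().__init__(*args, **kwargs)


# Assumed shape of the saved layer config: no 'pool2d_layer' key.
saved_config = {"name": "depool_1"}

try:
    DePoolRequired.from_config(saved_config)
    error_message = None
except TypeError as exc:
    # Same kind of failure as in the traceback above.
    error_message = str(exc)

# With a default value the constructor call succeeds.
restored = DePoolDefault.from_config(saved_config)
```

Note that in this sketch the restored layer always gets the default pooling object, not whatever was passed at training time, which may be part of why Approach 2 feels like a workaround.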

But I think I'm missing a more fundamental solution.
Does anyone have a more elegant one?
