Dataset Iterator is not an iterator #15273

Closed
@ruuda

Description

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): confidential
  • TensorFlow installed from (source or binary): binary (PyPI)
  • TensorFlow version (use command below): v1.4.0-19-ga52c8d9 1.4.1
  • Python version: 3.5.3
  • CUDA/cuDNN version: (sensitive information replaced by xxx)
$ apt search cud | grep installed
libcublas8.0/xxx,now 8.0.44-4 amd64 [installed]
libcuda1/xxx,now 375.66-1 amd64 [installed,automatic]
libcuda1-i386/xxx,now 375.66-1 i386 [installed,automatic]
libcudart8.0/xxx,now 8.0.44-4 amd64 [installed]
libcudnn6/now 6.0.21-1+cuda8.0 amd64 [installed,local]
libcufft8.0/xxx,now 8.0.44-4 amd64 [installed]
libcurand8.0/xxx,now 8.0.44-4 amd64 [installed]
libnvidia-fatbinaryloader/xxx,now 375.66-1 amd64 [installed,automatic]
libnvidia-ptxjitcompiler/xxx,now 375.66-1 amd64 [installed,automatic]
  • GPU model and memory: Quadro K1200, 4019 MiB
  • Exact command to reproduce:
    Run convert_to_records.py from the official MNIST example, then:
>>> import tensorflow as tf
>>> ds = tf.data.TFRecordDataset(['/tmp/mnist_data'])
>>> i  = ds.make_one_shot_iterator()
>>> next(i)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: 'Iterator' object is not an iterator

Describe the problem

The returned "iterator" is not an iterator, because it does not provide a __next__ or next method. It does provide a get_next method, but that is not what Python expects.
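
The failure can be reproduced without TensorFlow. The following is a minimal sketch, assuming a hypothetical stand-in class `GetNextOnly` (not a TensorFlow class) that, like the object described above, offers `get_next()` but no `__next__()`. The built-in `next()` only looks for `__next__`, so it rejects the object with the same kind of `TypeError`.

```python
class GetNextOnly:
    """Hypothetical stand-in: exposes get_next() but not __next__()."""

    def __init__(self, values):
        self._values = list(values)
        self._pos = 0

    def get_next(self):
        # Return the current value and advance the position.
        value = self._values[self._pos]
        self._pos += 1
        return value


obj = GetNextOnly([10, 20])

# The object has no __next__, so the built-in next() refuses it.
has_dunder_next = hasattr(obj, "__next__")
try:
    next(obj)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(has_dunder_next)
print(error_message)
print(obj.get_next())
```

Calling `get_next()` directly still works; only the Python iterator protocol is missing.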

Activity

snnn (Contributor) commented on Dec 11, 2017

I think it is expected. You can't write code like:

for item in ds.make_one_shot_iterator():
    print(item)

ruuda (Author) commented on Dec 11, 2017

Is there a technical reason that prevents Iterator from being an iterator? It looks like the iterator does something unexpected, and there is already a warning for that. But as far as I can see, nothing prevents a next() and __next__() method that simply call get_next().
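
The proposal amounts to a thin adapter. Below is a minimal sketch of that idea, assuming a hypothetical stand-in `SymbolicSource` (not a TensorFlow class) whose `get_next()` returns a placeholder string each time it is called. `NextAdapter` is likewise an illustrative name. Because nothing in this sketch ever raises `StopIteration`, the caller has to bound the iteration, here with `itertools.islice`.

```python
import itertools


class SymbolicSource:
    """Hypothetical stand-in for an object that only exposes get_next()."""

    def __init__(self):
        self.calls = 0

    def get_next(self):
        # Each call hands back a new placeholder, not a concrete value.
        self.calls += 1
        return "symbolic-tensor-%d" % self.calls


class NextAdapter:
    """Adds the Python iterator protocol by delegating to get_next()."""

    def __init__(self, source):
        self._source = source

    def __iter__(self):
        return self

    def __next__(self):
        return self._source.get_next()

    next = __next__  # Python 2 spelling of the same method


adapter = NextAdapter(SymbolicSource())
first_three = list(itertools.islice(adapter, 3))
print(first_three)
```

The adapter satisfies `next()`, but it only forwards whatever `get_next()` returns; it does not evaluate anything.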

ruuda (Author) commented on Dec 11, 2017

CC @mrry, who is the author of the note.

mrry (Contributor) commented on Dec 11, 2017

> Is there a technical reason that prevents Iterator from being an iterator?

Yes. In TensorFlow (except Eager mode), we use the tf.data.Iterator to get symbolic tf.Tensor objects that can be chained together with other operations to build a dataflow graph. We typically build the graph once, and use it many times. Wrapping the graph construction in a for loop (by using the tf.data.Iterator as a Python iterator) would not be efficient. Furthermore, tf.data.Iterator is implemented using TensorFlow operations, so you would need to provide a tf.Session to run these operations, and the Python iterator protocol doesn't provide a way to do that.
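
The "build once, run many times" pattern described above can be wrapped in a plain Python generator if the caller supplies the session explicitly. The sketch below is an illustration under stated assumptions, not part of TensorFlow: `iterate_with_session`, `FakeSession`, and `EndOfData` are hypothetical names, and the fake session stands in for `tf.Session` so the example runs without TensorFlow installed. In TensorFlow 1.x graph mode one would presumably pass `sess.run`, the tensor returned by a single `iterator.get_next()` call, and `tf.errors.OutOfRangeError`.

```python
def iterate_with_session(run, next_element, end_of_data):
    """Yield run(next_element) until `end_of_data` is raised.

    `next_element` is built once by the caller; only `run` is invoked per item.
    """
    while True:
        try:
            yield run(next_element)
        except end_of_data:
            return


class EndOfData(Exception):
    """Stand-in for tf.errors.OutOfRangeError (assumption for this sketch)."""


class FakeSession:
    """Stand-in for tf.Session: each run() call produces the next value."""

    def __init__(self, values):
        self._values = iter(values)

    def run(self, fetch):
        # The fetch is ignored here; a real session would evaluate the graph op.
        try:
            return next(self._values)
        except StopIteration:
            raise EndOfData()


sess = FakeSession([1, 2, 3])
next_element = "symbolic-handle"  # created once, like iterator.get_next()
items = list(iterate_with_session(sess.run, next_element, EndOfData))
print(items)
```

The session is passed in by the caller because, as noted above, the Python iterator protocol itself has no place to supply one.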

In eager mode, you can wrap a tf.data.Dataset in a tf.contrib.eager.Iterator, which is usable as a Python iterator. This definitely seems to be a more natural use of the tf.data API, and we'd encourage you to try it out!

          Dataset Iterator is not an iterator · Issue #15273 · tensorflow/tensorflow