У меня вопрос о новом API набора данных (tensorflow 1.4rc1). У меня несбалансированный набор данных по меткам 0
и 1
. Моя цель - создать сбалансированные мини-партии во время предварительной обработки.
Предположим, у меня есть два отфильтрованных набора данных:
ds_pos = dataset.filter(lambda l, x, y, z: tf.reshape(tf.equal(l, 1), []))
ds_neg = dataset.filter(lambda l, x, y, z: tf.reshape(tf.equal(l, 0), [])).repeat()
Есть ли способ объединить эти два набора данных так, чтобы результирующий набор данных выглядел как ds = [0, 1, 0, 1, 0, 1]
:
Что-то вроде этого:
dataset = tf.data.Dataset.zip((ds_pos, ds_neg))
dataset = dataset.apply(...)
# dataset looks like [0, 1, 0, 1, 0, 1, ...]
dataset = dataset.batch(20)
Мой текущий подход:
def _concat(x, y):
return tf.cond(tf.random_uniform(()) > 0.5, lambda: x, lambda: y)
dataset = tf.data.Dataset.zip((ds_pos, ds_neg))
dataset = dataset.map(_concat)
Но мне кажется, что есть более элегантный способ.
Заранее спасибо!