Код обратного распространения (во времени) в Tensorflow

Где я могу найти код обратного распространения (через время) в Tensorflow (API Python)? Или используются другие алгоритмы?

Например, когда я создаю сеть LSTM.


person Alex    schedule 20.04.2016    source источник


Ответы (2)


Все обратное распространение в TensorFlow реализовано путем автоматического дифференцирования операций в прямом проходе сети и добавления явных операций для вычисления градиента в каждой точке сети. Общую реализацию можно найти в tf.gradients(), но конкретная используемая версия зависит от того, как реализован ваш LSTM:

person mrry    schedule 20.04.2016
comment
Пожалуйста, взгляните на этот stackoverflow.com/q/66185202/14337775 и выскажите свое мнение - person Lawhatre; 17.02.2021

Я не уверен в этом, но это может сработать:

Поскольку RNN можно обучать так же, как сети с прямой связью, код очень похож. Вот как вы тренируете сеть с прямой связью: (X - вход)

train = tf.train.GradientDescentOptimizer(learning_rate).minimize(error)

# Session
sess = tf.Session()
sess.run(tf.initialize_all_variables())

for i in range(epochs):
    sess.run(train, feed_dict={X: [[0, 0, 1], [1, 1, 1], [1, 0, 1], [0, 1, 1]], labels: [[0], [1], [1], [0]]})

Единственная разница в обратном распространении во времени состоит в том, что каждая эпоха теперь имеет вложенный цикл времени.

Это код для обучения простого rnn:

train = tf.train.GradientDescentOptimizer(learning_rate).minimize(error)

time_series = [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
for i in range(number_of_epochs):
    for j in range(len(time_series) - 1):
        curr_X = time_series[j+1]
        curr_prev = time_series[j]
        lbs = curr_prev
        sess.run(train, feed_dict={X: [[curr_X]], prev_val: [[curr_prev]], labels: [[lbs]]})

В этом коде rnn изучает временной ряд с альтернативными единицами и нулями.

person Paramdeep Singh Obheroi    schedule 30.05.2017