Для ясности рассмотрим два случая.
Случай 1: Простая модель и
Случай 2: Сложная модель, в которой использовались определяемые пользователем классы, унаследованные от tf.keras.Model
.
Случай 1: Простая модель (как в функциональных и последовательных моделях keras)
Когда вы сохраняете веса модели (используя model.save_weights
), а затем загружаете веса (используя model.load_weights
), по умолчанию метод load_weights
использует топологическую загрузку. Это то же самое для формата Tensorflow saved_model ('tf'), а также для формата 'h5'. Например,
loadedh5_model.load_weights('./MyModel_h5.h5')
# the line above is same as the line below (as second and third arguments are default)
#loadedh5_model.load_weights('./MyModel_h5.h5',by_name=False, skip_mismatch=False)
В случае, если вы хотите загрузить веса определенных слоев сохраненной модели, вам необходимо использовать by_name=True
. Существуют варианты использования, требующие такого типа загрузки.
loadedh5_model.load_weights('./MyModel_h5.h5',by_name=True, skip_mismatch=False)
Случай 2: Сложная модель (как в моделях подкласса Keras)
На данный момент поддерживается только формат tf, только если при создании модели использовались определенные пользователем классы, унаследованные от tf.keras.Model
.
При загрузке весов из формата TensorFlow поддерживается только топологическая загрузка (by_name = False). Обратите внимание, что топологическая загрузка форматов TensorFlow и HDF5 немного отличается для определяемых пользователем классов, унаследованных от tf.keras.Model: HDF5 загружается на основе сглаженного списка весов, в то время как формат TensorFlow загружается на основе локальных имен объектов атрибутов, для которых слои назначаются в конструкторе модели.
Основная причина в том, что веса имеют формат h5
и формат tf
. Например, рассмотрим Case 1
, где HDF5 загружается на основе упорядоченного списка весов. Вес загружается без ошибок. Однако в Case 2
модель имеет user defined classes
подход, который требует другого подхода, чем просто загрузка плоских весов. Чтобы позаботиться о назначении весов пользовательских классов, формат tf загружает веса на основе локальных имен объектов атрибутов, которым слои назначаются в конструкторе модели.
Следующий абзац, упомянутый на веб-сайте keras, дополнительно разъясняет
При загрузке файла веса в формате TensorFlow возвращает тот же объект статуса, что и tf.train.Checkpoint.restore. При построении графа операции восстановления запускаются автоматически, как только сеть построена (при первом вызове определяемых пользователем классов, наследующих от модели, немедленно, если она уже построена).
Еще один момент, который следует понять, - это модели keras Functional
или Sequential
- это статические графики слоев, которые могут без проблем использовать сглаженные веса. Модель Keras с подклассом (как в нашем случае 2) представляет собой фрагмент кода Python (метод вызова). Графика слоев нет. Поэтому, как только сеть построена с использованием настраиваемых классов, запускаются операции восстановления для обновления объектов состояния. Надеюсь, это поможет.
person
Vishnuvardhan Janapati
schedule
10.05.2020