Лучший способ сохранить обученную модель в PyTorch?

Искал альтернативные способы сохранить обученную модель в PyTorch. Пока что я нашел две альтернативы.

  1. torch.save (), чтобы сохранить модель и torch. load (), чтобы загрузить модель.
  2. model.state_dict () для сохранения обученной модели и model.load_state_dict (), чтобы загрузить сохраненную модель.

Я наткнулся на это обсуждение, где подход 2 рекомендуется, а не подход 1.

У меня вопрос: почему предпочтение отдается второму подходу? Это только потому, что модули torch.nn имеют эти две функции, и нам рекомендуется их использовать?


person Wasi Ahmad    schedule 09.03.2017    source источник
comment
Я думаю, это потому, что torch.save () также сохраняет все промежуточные переменные, такие как промежуточные выходные данные для использования обратного распространения. Но вам нужно только сохранить параметры модели, такие как вес / смещение и т. Д. Иногда первые могут быть намного больше, чем вторые.   -  person Dawei Yang    schedule 18.03.2017
comment
Я тестировал torch.save(model, f) и torch.save(model.state_dict(), f). Сохраненные файлы имеют одинаковый размер. Теперь я запутался. Кроме того, я обнаружил, что использование pickle для сохранения model.state_dict () очень медленное. Я думаю, что лучший способ - использовать torch.save(model.state_dict(), f), поскольку вы занимаетесь созданием модели, а torch обрабатывает загрузку весов модели, тем самым устраняя возможные проблемы. Ссылка: Discussion.pytorch.org/t/saving-torch-models/ 838/4   -  person Dawei Yang    schedule 29.03.2017
comment
Похоже, PyTorch рассмотрел этот вопрос более подробно в своем разделе руководств - там много хорошая информация, которая не указана в ответах здесь, в том числе сохранение более одной модели за раз и модели с теплым запуском.   -  person whlteXbread    schedule 25.03.2019
comment
что не так с использованием pickle?   -  person Charlie Parker    schedule 13.07.2020
comment
@CharlieParker torch.save основан на рассоле. Следующее взято из учебника, ссылка на который приведена выше: [torch.save] сохранит весь модуль, используя модуль pickle Python. Недостатком этого подхода является то, что сериализованные данные привязаны к определенным классам и точной структуре каталогов, используемой при сохранении модели. Причина этого в том, что pickle не сохраняет сам класс модели. Скорее, он сохраняет путь к файлу, содержащему класс, который используется во время загрузки. Из-за этого ваш код может ломаться по-разному при использовании в других проектах или после рефакторинга.   -  person David Miller    schedule 14.07.2020
comment
@DavidMiller на самом деле мне нужно только сохранить nn.Sequential model. Вы знаете, как это сделать? У меня нет определения класса модели. Для последовательного, я написал это, надеюсь, уважаемый ответчик подтвердит: stackoverflow.com/questions/62923052/   -  person Charlie Parker    schedule 15.07.2020


Ответы (7)


Я нашел эту страницу на их github. репо, я просто вставлю сюда содержимое.


Рекомендуемый подход для сохранения модели

Есть два основных подхода к сериализации и восстановлению модели.

Первый (рекомендуемый) сохраняет и загружает только параметры модели:

torch.save(the_model.state_dict(), PATH)

Тогда позже:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

Второй сохраняет и загружает всю модель:

torch.save(the_model, PATH)

Тогда позже:

the_model = torch.load(PATH)

Однако в этом случае сериализованные данные привязаны к конкретным классам и конкретной используемой структуре каталогов, поэтому они могут ломаться по-разному при использовании в других проектах или после серьезных рефакторингов.

person dontloo    schedule 06.05.2017
comment
Согласно @smth Discussion.pytorch. org / t / save-and-loading-a-model-in-pytorch / модель перезагружается для обучения модели по умолчанию. поэтому необходимо вручную вызвать the_model.eval () после загрузки, если вы загружаете его для вывода, а не возобновления обучения. - person WillZ; 16.07.2018
comment
второй метод дает stackoverflow.com / questions / 53798009 / ошибка в Windows 10. не удалось ее решить - person Gulzar; 16.12.2018
comment
Есть ли возможность сохранить без доступа к классу модели? - person Michael D; 11.12.2019
comment
При таком подходе, как вы отслеживаете * args и ** kwargs, которые вам нужно передать для случая нагрузки? - person Mariano Kamp; 09.04.2020
comment
на самом деле мне нужно сохранить только nn.Sequential модель. Вы знаете, как это сделать? У меня нет определения класса модели. - person Charlie Parker; 15.07.2020
comment
@CharlieParker, извините, не знаю, слишком долго не работал над pytorch - person dontloo; 15.07.2020
comment
@dontloo the_model = TheModelClass (* args, ** kwargs). Выполнение этой команды сообщает NameError: имя TheModelClass не определено. Как мне это сделать - person Naren Babu R; 20.07.2020
comment
@NarenBabuR, это было фиктивное имя. Вы должны заменить его фактическим классом модели, который вы создали, или какой бы то ни было предопределенной моделью, которую вы использовали в PyTorch. Аргументы и kwargs - это все, что вы использовали для определения модели. Вы можете сохранить как веса, так и смещения, а также дополнительные параметры, необходимые для загрузки модели. См. Второй ответ ниже. - person rayryeng; 11.04.2021
comment
Я сохранил модель rnn вторым методом. После загрузки я хочу делать прогнозы и рассчитывать такие показатели, как F1. Однако предсказание не работает. Я получаю эту ошибку IndexError: index out of range in self при запуске этого: predictions = model(batch.textt).squeeze(1) - person mah65; 24.04.2021

Это зависит от того, чем вы хотите заниматься.

Случай № 1. Сохраните модель, чтобы использовать ее для вывода: вы сохраняете модель, восстанавливаете ее, а затем переводите модель в режим оценки. Это сделано потому, что обычно у вас есть слои BatchNorm и Dropout, которые по умолчанию находятся в режиме обучения при строительстве:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

Случай № 2: Сохраните модель, чтобы продолжить обучение позже. Если вам нужно продолжить обучение модели, которую вы собираетесь сохранить, вам нужно сохранить не только модель. Вам также необходимо сохранить состояние оптимизатора, эпохи, счет и т. Д. Вы бы сделали это так:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

Чтобы возобновить обучение, вы должны сделать что-то вроде: state = torch.load(filepath), а затем, чтобы восстановить состояние каждого отдельного объекта, что-то вроде этого:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Поскольку вы возобновляете обучение, НЕ вызывайте model.eval() после восстановления состояний при загрузке.

Случай № 3: Модель будет использоваться кем-то другим, не имеющим доступа к вашему коду: в Tensorflow вы можете создать .pb файл, который определяет как архитектуру, так и веса модели. Это очень удобно, особенно при использовании Tensorflow serve. Эквивалентный способ сделать это в Pytorch:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

Этот способ все еще не является пуленепробиваемым, и, поскольку pytorch все еще претерпевает множество изменений, я бы не рекомендовал его.

person Jadiel de Armas    schedule 02.03.2018
comment
Есть ли рекомендуемый конец файла для 3 случаев? Или это всегда .pth? - person Verena Haunschmid; 12.02.2019
comment
В случае № 3 torch.load возвращает просто OrderedDict. Как получить модель, чтобы делать прогнозы? - person Alber8295; 12.02.2019
comment
Привет, Могу ли я узнать, как сделать упомянутый случай № 2: Сохранить модель, чтобы продолжить обучение позже? Мне удалось загрузить контрольную точку в модель, затем я не смог запустить или возобновить обучение модели, такой как model.to (устройство) model = train_model_epoch (модель, критерий, оптимизатор, расписание, эпохи) - person dnez; 08.03.2019
comment
Привет, для случая, который предназначен для вывода, в официальном документе pytorch говорится, что он должен сохранять оптимизатор state_dict либо для вывода, либо для завершения обучения. При сохранении общей контрольной точки, которая будет использоваться либо для вывода, либо для возобновления обучения, вы должны сохранить больше, чем просто state_dict модели. Также важно сохранить state_dict оптимизатора, поскольку он содержит буферы и параметры, которые обновляются по мере обучения модели. - person Mohammed Awney; 21.09.2019
comment
В случае № 3 класс модели должен быть где-то определен. - person Michael D; 11.12.2019
comment
Другой вопрос касается некоторых беспорядков. дела №3. Я написал аналогичную заставку для модели для Keras, и этот процесс сохранения исходного кода модели по своей сути беспорядочный. - person Josiah Yoder; 28.08.2020
comment
Для варианта использования № 3 вам, вероятно, понадобится формат обмена моделями, такой как ONNX, а не копирование всей модели. - person Nzbuu; 09.05.2021

Библиотека Python pickle реализует двоичные протоколы для сериализации и десериализации объекта Python. .

Когда вы import torch (или когда вы используете PyTorch), он будет import pickle для вас, и вам не нужно напрямую вызывать pickle.dump() и pickle.load(), которые являются методами для сохранения и загрузки объекта.

Фактически, torch.save() и torch.load() обернут для вас pickle.dump() и pickle.load().

state_dict другой упомянутый ответ заслуживает еще нескольких заметок.

Что state_dict у нас внутри PyTorch? На самом деле есть два state_dict.

Модель PyTorch torch.nn.Module, которая имеет model.parameters() вызов для получения обучаемых параметров (w и b). Эти изучаемые параметры, однажды установленные случайным образом, будут обновляться с течением времени по мере нашего обучения. Обучаемые параметры - это первые state_dict.

Второй state_dict - это определение состояния оптимизатора. Вы помните, что оптимизатор используется для улучшения наших обучаемых параметров. Но оптимизатор state_dict исправлен. Здесь нечему учиться.

Поскольку state_dict объекты являются словарями Python, их можно легко сохранять, обновлять, изменять и восстанавливать, добавляя большую модульность моделям и оптимизаторам PyTorch.

Давайте создадим суперпростую модель, чтобы объяснить это:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Этот код выведет следующее:

Model's state_dict:
weight      torch.Size([2, 5])
bias      torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state      {}
param_groups      [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

Обратите внимание, это минимальная модель. Вы можете попробовать добавить стопку последовательных

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

Обратите внимание, что только слои с обучаемыми параметрами (сверточные слои, линейные слои и т. Д.) И зарегистрированными буферами (слои батчнорма) имеют записи в state_dict модели.

Необучаемые вещи принадлежат объекту оптимизатора state_dict, который содержит информацию о состоянии оптимизатора, а также об используемых гиперпараметрах.

В остальном история такая же; на этапе вывода (это этап, когда мы используем модель после обучения) для прогнозирования; мы делаем прогнозы на основе изученных нами параметров. Итак, для вывода нам просто нужно сохранить параметры model.state_dict().

torch.save(model.state_dict(), filepath)

И для использования более поздних версий model.load_state_dict (torch.load (filepath)) model.eval ()

Примечание: не забудьте последнюю строку model.eval(), это очень важно после загрузки модели.

И не пытайтесь экономить torch.save(model.parameters(), filepath). model.parameters() - это просто объект-генератор.

С другой стороны, torch.save(model, filepath) сохраняет сам объект модели, но имейте в виду, что модель не имеет state_dict оптимизатора. Проверьте другой отличный ответ @Jadiel de Armas, чтобы сохранить состояние оптимизатора.

person prosti    schedule 17.04.2019
comment
Хотя это не однозначное решение, суть проблемы глубоко проанализирована! Голосовать за. - person Jason Young; 02.06.2020

Распространенное соглашение PyTorch - сохранять модели с расширением файла .pt или .pth.

Сохранить / загрузить всю модель

Сохранить:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

Загрузка:

(Класс модели должен быть где-то определен)

model.load_state_dict(torch.load(PATH))
model.eval()
person harsh    schedule 13.05.2019
comment
он поднял: AttributeError: объект 'dict' не имеет атрибута 'eval' - person DennisLi; 05.05.2021

Если вы хотите сохранить модель и хотите продолжить обучение позже:

Один графический процессор. Сохранить:

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Нагрузка:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

Несколько графических процессоров: сохранить

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Нагрузка:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

#Don't call DataParallel before loading the model otherwise you will get an error

model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU
person Joy Mazumder    schedule 21.05.2020

Локальное сохранение

Как вы сохраните свою модель, зависит от того, как вы хотите получить к ней доступ в будущем. Если вы можете вызвать новый экземпляр класса model, тогда все, что вам нужно сделать, это сохранить / загрузить веса модели с помощью model.state_dict():

# Save:
torch.save(old_model.state_dict(), PATH)

# Load:
new_model = TheModelClass(*args, **kwargs)
new_model.load_state_dict(torch.load(PATH))

Если вы не можете по какой-либо причине (или предпочитаете более простой синтаксис), вы можете сохранить всю модель (фактически ссылку на файл (ы), определяющие модель, вместе с его state_dict) с помощью torch.save():

# Save:
torch.save(old_model, PATH)

# Load:
new_model = torch.load(PATH)

Но поскольку это ссылка на расположение файлов, определяющих класс модели, этот код не переносится, если эти файлы также не перенесены в ту же структуру каталогов.

Сохранение в облако - TorchHub

Если вы хотите, чтобы ваша модель была портативной, вы можете легко разрешить ее импорт с помощью torch.hub. Если вы добавите должным образом определенный hubconf.py файл в репозиторий github, его можно легко вызвать из PyTorch, чтобы пользователи могли загружать вашу модель с весами или без них:

hubconf.py (github.com/repo_owner/repo_name)

dependencies = ['torch']
from my_module import mymodel as _mymodel

def mymodel(pretrained=False, **kwargs):
    return _mymodel(pretrained=pretrained, **kwargs)

Модель загрузки:

new_model = torch.hub.load('repo_owner/repo_name', 'mymodel')
new_model_pretrained = torch.hub.load('repo_owner/repo_name', 'mymodel', pretrained=True)
person iacob    schedule 02.04.2021

В наши дни все написано в официальном руководстве: https://pytorch.org/tutorials/beginner/saving_loading_models.html

У вас есть несколько вариантов, как сохранить и что сохранить, и все это объясняется в этом руководстве.

person bruziuz    schedule 15.07.2021