Библиотека Python pickle реализует двоичные протоколы для сериализации и десериализации объекта Python. .
Когда вы import torch
(или когда вы используете PyTorch), он будет import pickle
для вас, и вам не нужно напрямую вызывать pickle.dump()
и pickle.load()
, которые являются методами для сохранения и загрузки объекта.
Фактически, torch.save()
и torch.load()
обернут для вас pickle.dump()
и pickle.load()
.
state_dict
другой упомянутый ответ заслуживает еще нескольких заметок.
Что state_dict
у нас внутри PyTorch? На самом деле есть два state_dict
.
Модель PyTorch torch.nn.Module
, которая имеет model.parameters()
вызов для получения обучаемых параметров (w и b). Эти изучаемые параметры, однажды установленные случайным образом, будут обновляться с течением времени по мере нашего обучения. Обучаемые параметры - это первые state_dict
.
Второй state_dict
- это определение состояния оптимизатора. Вы помните, что оптимизатор используется для улучшения наших обучаемых параметров. Но оптимизатор state_dict
исправлен. Здесь нечему учиться.
Поскольку state_dict
объекты являются словарями Python, их можно легко сохранять, обновлять, изменять и восстанавливать, добавляя большую модульность моделям и оптимизаторам PyTorch.
Давайте создадим суперпростую модель, чтобы объяснить это:
import torch
import torch.optim as optim
model = torch.nn.Linear(5, 2)
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
print("Model weight:")
print(model.weight)
print("Model bias:")
print(model.bias)
print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
Этот код выведет следующее:
Model's state_dict:
weight torch.Size([2, 5])
bias torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328, 0.1360, 0.1553, -0.1838, -0.0316],
[ 0.0479, 0.1760, 0.1712, 0.2244, 0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]
Обратите внимание, это минимальная модель. Вы можете попробовать добавить стопку последовательных
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.Conv2d(A, B, C)
torch.nn.Linear(H, D_out),
)
Обратите внимание, что только слои с обучаемыми параметрами (сверточные слои, линейные слои и т. Д.) И зарегистрированными буферами (слои батчнорма) имеют записи в state_dict
модели.
Необучаемые вещи принадлежат объекту оптимизатора state_dict
, который содержит информацию о состоянии оптимизатора, а также об используемых гиперпараметрах.
В остальном история такая же; на этапе вывода (это этап, когда мы используем модель после обучения) для прогнозирования; мы делаем прогнозы на основе изученных нами параметров. Итак, для вывода нам просто нужно сохранить параметры model.state_dict()
.
torch.save(model.state_dict(), filepath)
И для использования более поздних версий model.load_state_dict (torch.load (filepath)) model.eval ()
Примечание: не забудьте последнюю строку model.eval()
, это очень важно после загрузки модели.
И не пытайтесь экономить torch.save(model.parameters(), filepath)
. model.parameters()
- это просто объект-генератор.
С другой стороны, torch.save(model, filepath)
сохраняет сам объект модели, но имейте в виду, что модель не имеет state_dict
оптимизатора. Проверьте другой отличный ответ @Jadiel de Armas, чтобы сохранить состояние оптимизатора.
person
prosti
schedule
17.04.2019
torch.save(model, f)
иtorch.save(model.state_dict(), f)
. Сохраненные файлы имеют одинаковый размер. Теперь я запутался. Кроме того, я обнаружил, что использование pickle для сохранения model.state_dict () очень медленное. Я думаю, что лучший способ - использоватьtorch.save(model.state_dict(), f)
, поскольку вы занимаетесь созданием модели, а torch обрабатывает загрузку весов модели, тем самым устраняя возможные проблемы. Ссылка: Discussion.pytorch.org/t/saving-torch-models/ 838/4 - person Dawei Yang   schedule 29.03.2017pickle
? - person Charlie Parker   schedule 13.07.2020nn.Sequential model
. Вы знаете, как это сделать? У меня нет определения класса модели. Для последовательного, я написал это, надеюсь, уважаемый ответчик подтвердит: stackoverflow.com/questions/62923052/ - person Charlie Parker   schedule 15.07.2020