Почему вывод на мою модель pytorch процессора не воспроизводится?

Недавно я начал работать с pytorch и заметил, что не получаю повторяемых / детерминированных результатов при оценке предварительно обученной модели на новых входных данных.

Я свел проблему к этому минимальному примеру, который показывает, что многократное применение одной и той же простой модели свертки не дает идентичных результатов:

import numpy as np 
import matplotlib.pyplot as plt 
import torch

device = torch.device('cpu')

# function to get all the params from a pytorch model
def getParams(model):
    a = list(model.parameters())
    b = [a[i].detach().cpu().numpy() for i in range(len(a))]
    c = [b[i].flatten() for i in range(len(b))]
    d = np.hstack(c)

    return d

# set up a simple model (9 params)
testModule = torch.nn.Conv2d(1, 1, kernel_size = (3, 3), bias = False, stride = 1, padding = 1).double()
torch.nn.init.normal_(testModule.weight, mean=0, std=1)
testModule = testModule.eval()

# set up a dummy input
patch = torch.from_numpy(np.random.randn(1,1,80,80).astype('double')).to(device)

# apply the model 100 times
testVals = []
testParams = []
testModuleOut = []
for ii in range(100):
    testParams.append(getParams(testModule))
    testModuleOut.append(testModule(patch).cpu().detach()[0,:,:,:].numpy())

testParams = np.stack(testParams)
testModuleOut = np.stack(testModuleOut)

# view the variation of the model parameters and the output values
plt.figure()
plt.plot(np.std(testParams,axis=0))
plt.xlabel('Parameter index')
plt.ylabel('Standard deviation over runs')

plt.figure()
plt.plot(np.std(testModuleOut,axis=0).ravel())
plt.xlabel('Output index')
plt.ylabel('Standard deviation over runs')

Если бы повторный запуск сети был повторяемым, я бы ожидал, что графики стандартного отклонения будут показывать плоские линии при SD = 0. Но я этого не понимаю, вместо этого я получаю несколько случайных линий графика, которые меняются с каждым запуском скрипта ( иногда параметры модуля имеют SD = 0, но сетевой выход никогда не бывает).

В чем проблема с моим кодом? Кажется, что SD ориентированы на машинную точность, но почему многократное извлечение параметров из модуля заставляет их изменяться таким образом? Разве мы не будем просто извлекать одно и то же значение из памяти?


person user11305730    schedule 03.04.2019    source источник