Ошибка времени выполнения для (предположительно) пустых партий в Pytorch

(Обновлено с дополнительными сведениями о проблеме)

У меня есть набор данных из 3000 изображений, который попадает в DataLoader со следующими строками:

training_left_eyes = torch.utils.data.DataLoader(train_dataset, batch_size=2,shuffle=True, drop_last=True)
print(len(training_left_eyes)) #Outputs 1500

Мой цикл обучения выглядит так:

for i,(data,output) in enumerate(training_left_eyes):
      data,output = data.to(device),output.to(device)
      prediction = net(data)

      loss = costFunc(prediction,output)
      closs = loss.item()
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      print("batch #{}".format(i))
      if i%100 == 0:
        print('[%d %d] loss: %.8f' % (epoch+1,i+1,closs/1000))
        closs = 0

Информация внутри тензоров «данные» и «вывод» (метки) верна, и система работает нормально, пока не достигнет номера партии 1500. Все мои партии заполнены как 3000/2 = 1500 без остатка. Как только он достигает этой последней партии, возникает ошибка RunTimeError, указывающая, что существует 0-мерный размер ввода. Но я не знаю, почему это должно происходить, поскольку enumerate (training_left_eyes) должен перебирать значения DataLoader, которые заполнены.

Я искал в Интернете, как решить эту проблему, и некоторые люди упоминали атрибут drop_last = True в DataLoader, хотя это было сделано для того, чтобы полупустые партии не попадали в модель, я все равно пробовал безрезультатно.

У меня слишком сильное зрение, и я не могу решить проблему самостоятельно. Я мог бы просто вставить оператор if, но я думаю, что это было бы очень плохой практикой, и я хочу узнать правильное решение.

Если это поможет, вот мой собственный DataSet:

class LeftEyeDataset(torch.utils.data.Dataset):
  """Left eye retinography dataset. Normal/Not-normal"""

  def __init__(self, csv_file, root_dir, transform=None):
    """
    Args:
        csv_file (string): Path to the csv file with annotations.
        root_dir (string): Directory with all the images.
        transform (callable, optional): Optional transform to be applied
            on a sample.
    """
    self.labels  = label_mapping(csv_file)
    self.root_dir = root_dir
    self.transform = transform
    self.names = name_mapping(csv_file)

  def __len__(self):
    return len(self.labels)

  def __getitem__(self, idx):
    if torch.is_tensor(idx):
        idx = idx.tolist()

    img_name = self.root_dir +'/'+ self.names[idx]
    image = io.imread(img_name)
    label = self.labels[idx]

    if self.transform:
        image = self.transform(image)

    return image,label


def label_mapping(csv_file) -> np.array:
  df = read_excel(excel_file, 'Sheet1')
  x= []
  for key,value in df['Left-Diagnostic Keywords'].iteritems():
    if value=='normal fundus':
      x.append(1)
    else:
      x.append(0)
  x_tensor = torch.LongTensor(x)
  return x_tensor

def name_mapping(csv_file) -> list:
  #Reads the names of the excel file
  df = read_excel(excel_file, 'Sheet1')
  names= list()
  serie = df['Left-Fundus']
  for i in range(serie.size):
    names.append(df['Left-Fundus'][i])
  return names

При необходимости я могу предоставить любой дополнительный код.

Обновление: через некоторое время, пытаясь решить проблему, мне удается точно определить, что происходит. По какой-то причине в последнем пакете данные, поступающие в сеть, в порядке, но прямо перед первым слоем что-то происходит, и они исчезают. На следующем изображении вы можете увидеть отпечаток, который я делаю перед вводом forward (self, x) и один сразу после него. Размеры выровнены до номера партии 61 (я уменьшил его с 1500 для этого примера), в котором он каким-то образом проходит печать дважды. После этой строки появляется указанная выше ошибка.

Снимок экрана с попыткой увидеть, что происходит с данными


person Carlos Hernandez Perez    schedule 23.11.2019    source источник


Ответы (1)


После небольшой отладки резиновой утки я понял, что проблема связана не с обучением, а с набором проверки. Код читается

for i,(data,output) in validate_left_eyes:
      data,output = data.to(device),output.to(device)
      prediction = net(data)

Validate_left_eyes_ не был обернут функцией enumerate (), и поэтому первый пакет данных был пуст. Затем проблема была решена.

Прошу прощения, так как эта часть кода не была упомянута в моем вопросе, и поэтому ответ был не таким однозначным.

person Carlos Hernandez Perez    schedule 24.11.2019