Model () получила несколько значений для аргумента 'nr_class' - мультиклассификационная модель SpaCy (интеграция BERT)

Привет, я работаю над реализацией модели с несколькими классификациями (5 классов) с новой моделью SpaCy en_pytt_bertbaseuncased_lg. Код для новой трубы здесь:

nlp = spacy.load('en_pytt_bertbaseuncased_lg')
textcat = nlp.create_pipe(
    'pytt_textcat',
    config={
        "nr_class":5,
        "exclusive_classes": True,
    }
)
nlp.add_pipe(textcat, last = True)

textcat.add_label("class1")
textcat.add_label("class2")
textcat.add_label("class3")
textcat.add_label("class4")
textcat.add_label("class5")

Код для обучения следующий и основан на примере отсюда (https://pypi.org/project/spacy-pytorch-transformers/):

def extract_cat(x):
    for key in x.keys():
        if x[key]:
            return key

# get names of other pipes to disable them during training
n_iter = 250 # number of epochs

train_data = list(zip(train_texts, [{"cats": cats} for cats in train_cats]))


dev_cats_single   = [extract_cat(x) for x in dev_cats]
train_cats_single = [extract_cat(x) for x in train_cats]
cats = list(set(train_cats_single))
recall = {}
for c in cats:
    if c is not None: 
        recall['dev_'+c] = []
        recall['train_'+c] = []



optimizer = nlp.resume_training()
batch_sizes = compounding(1.0, round(len(train_texts)/2), 1.001)

for i in range(n_iter):
    random.shuffle(train_data)
    losses = {}
    batches = minibatch(train_data, size=batch_sizes)
    for batch in batches:
        texts, annotations = zip(*batch)
        nlp.update(texts, annotations, sgd=optimizer, drop=0.2, losses=losses)
    print(i, losses)

Итак, структура моих данных выглядит так:

[('TEXT TEXT TEXT',
  {'cats': {'class1': False,
    'class2': False,
    'class3': False,
    'class4': True,
    'class5': False}}), ... ]

Я не уверен, почему я получаю следующую ошибку:

TypeError                                 Traceback (most recent call last)
<ipython-input-32-1588a4eadc8d> in <module>
     21 
     22 
---> 23 optimizer = nlp.resume_training()
     24 batch_sizes = compounding(1.0, round(len(train_texts)/2), 1.001)
     25 

TypeError: Model() got multiple values for argument 'nr_class'

РЕДАКТИРОВАТЬ:

если я уберу аргумент nr_class, я получу вот эту ошибку:

ValueError: operands could not be broadcast together with shapes (1,2) (1,5)

Я действительно думал, что это произойдет, потому что я не указал аргумент nr_class. Это правильно?


person Henryk Borzymowski    schedule 13.08.2019    source источник


Ответы (1)


Это регресс в самой последней выпущенной нами версии spacy-pytorch-transformers. Извини за это!

Основная причина в том, что это еще один случай зла **kwargs. Я с нетерпением жду возможности усовершенствовать API spaCy, чтобы предотвратить эти проблемы в будущем.

Вы можете увидеть оскорбительную строчку здесь: : //github.com/explosion/spacy-pytorch-transformers/blob/c1def95e1df783c69bff9bc8b40b5461800e9231/spacy_pytorch_transformers/pipeline/textcat.py#L71. Мы предоставляем позиционный аргумент nr_class, который перекрывается с явным аргументом, который вы передали во время конфигурации.

Чтобы обойти проблему, вы можете просто удалить ключ nr_class из своего config dict, который вы передаете в spacy.create_pipe().

person syllogism_    schedule 13.08.2019
comment
Я понимаю, но дело в том, что если я уберу аргумент nr_class, я получу вот эту ошибку: ValueError: операнды не могут транслироваться вместе с shape (1,2) (1,5) Я действительно думал, что это произойдет, потому что я не указал аргумент nr_class. Это правильно? - person Henryk Borzymowski; 13.08.2019
comment
@HenrykBorzymowski У меня такая же проблема. Вы можете добавить architecture='softmax_pooler_output' в config dict при создании канала, и это, вероятно, сработает. Однако другие архитектуры, такие как softmax_class_vector или softmax_last_hidden, выдают эту ошибку в случае мультиклассовой классификации. Я изучил исходный код и не мог понять, откуда берется этот 2 в случае использования softmax_class_vector. - person today; 19.08.2019
comment
Было бы здорово, если бы вы могли взглянуть на эту актуальную проблему на Github, @syllogism_. - person today; 19.08.2019