class BertClassifier(nn.Module):
#Bert Model for Classification Tasks.
def __init__(self, freeze_bert=False):
"""
@param bert: a BertModel object
@param classifier: a torch.nn.Module classifier
@param freeze_bert (bool): Set `False` to fine-tune the BERT model
"""
super(BertClassifier, self).__init__()
# Specify hidden size of BERT, hidden size of our classifier, and number of labels
D_in, H, D_out = 768, 50, 2
# Instantiate BERT model
self.bert = BertModel.from_pretrained('bert-base-uncased')
# Instantiate an one-layer feed-forward classifier
self.classifier = nn.Sequential(
nn.Linear(D_in, H),
nn.ReLU(),
#nn.Dropout(0.5),
nn.Linear(H, D_out)
)
# Freeze the BERT model
if freeze_bert:
for param in self.bert.parameters():
param.requires_grad = False
Я хочу знать, что если мы вызываем bert_classifier = BertClassifier(freeze_bert=False)
, что означает, что мы не замораживаем веса, значит ли это, что мы выполняем точную настройку? Пожалуйста, поправьте меня, если я ошибаюсь.