| import torch |
| from torch import nn |
| from transformers import DistilBertModel, DistilBertTokenizer |
|
|
class MultilabelClassifier(nn.Module):
    """Multilabel classification head on top of a pretrained DistilBERT encoder.

    The backbone is always loaded with ``DistilBertModel.from_pretrained``.
    The classifier's input width is read from the loaded backbone's config
    rather than hard-coded. As a result, DistilBERT checkpoints whose hidden
    size is not 768 also work.

    Args:
        model_name: Name or path passed to ``DistilBertModel.from_pretrained``.
        num_labels: Number of independent labels, i.e. the output width.
        dropout: Dropout probability applied to the pooled representation.
    """

    def __init__(self, model_name: str, num_labels: int, dropout: float = 0.1):
        super().__init__()

        self.backbone = DistilBertModel.from_pretrained(model_name)
        # Derive the feature width from the checkpoint instead of assuming 768.
        # (``hidden_size`` is DistilBertConfig's alias for its ``dim`` field.)
        hidden_size = self.backbone.config.hidden_size
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.sigmoid = nn.Sigmoid()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        *,
        return_logits: bool = False,
    ) -> torch.Tensor:
        """Score every label independently for each sequence in the batch.

        Args:
            input_ids: Token ids forwarded unchanged to the backbone.
            attention_mask: Attention mask forwarded unchanged to the backbone.
            return_logits: If True, return raw pre-sigmoid scores. Use this
                with ``nn.BCEWithLogitsLoss``, which is numerically more
                stable than applying sigmoid followed by ``nn.BCELoss``.
                Defaults to False, which keeps the original behavior.

        Returns:
            A tensor of shape (batch, num_labels). It holds per-label
            probabilities in (0, 1), or logits when ``return_logits`` is True.
        """
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        # Pool by taking the hidden state at sequence position 0. For DistilBERT
        # tokenizers this is the [CLS] token.
        pooled_output = outputs.last_hidden_state[:, 0]
        logits = self.classifier(self.dropout(pooled_output))
        if return_logits:
            return logits
        return self.sigmoid(logits)
|
|