from sentence_transformers import SentenceTransformer, models
from transformers import BertTokenizer
import torch

# Use a GPU if one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


class BertForSTS(torch.nn.Module):
    """BERT encoder with a pooling layer, wrapped as a SentenceTransformer,
    for semantic textual similarity (STS)."""

    def __init__(self):
        super(BertForSTS, self).__init__()
        self.bert = models.Transformer('bert-base-uncased', max_seq_length=128)
        self.pooling_layer = models.Pooling(self.bert.get_word_embedding_dimension())
        self.sts_bert = SentenceTransformer(modules=[self.bert, self.pooling_layer])

    def forward(self, input_data):
        # Returns one sentence embedding per input sentence.
        output = self.sts_bert(input_data)['sentence_embedding']
        return output


def predict_similarity(sentence_pair):
    """Tokenize a pair of sentences and return their cosine similarity."""
    test_input = tokenizer(sentence_pair, padding='max_length', max_length=128,
                           truncation=True, return_tensors="pt").to(device)
    # The SentenceTransformer modules only expect input_ids and attention_mask.
    del test_input['token_type_ids']
    with torch.no_grad():
        output = model(test_input)
    sim = torch.nn.functional.cosine_similarity(output[0], output[1], dim=0).item()
    return sim


if __name__ == '__main__':
    PATH = 'bert-sts.pt'
    model = BertForSTS()
    model.load_state_dict(torch.load(PATH, map_location=device))
    model.to(device)  # move the model to the same device as the tokenized inputs
    model.eval()
    # Russian test pair: "I want to tell this Obama" / "I want to tell this Biden".
    first_text = ['хочу сказать этому обэме', 'хочу сказать этому байдену']
    print(predict_similarity(first_text))
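    # Illustrative usage sketch, not part of the original paste: the helper can
    # score any sentence pair the same way. The sentences below are made-up
    # examples; a cosine similarity near 1.0 suggests near-paraphrases, while a
    # value near 0.0 suggests the sentences are unrelated.
    second_text = ['the cat sits on the mat', 'a cat is sitting on a mat']
    third_text = ['the cat sits on the mat', 'stock prices fell sharply today']
    print(predict_similarity(second_text))  # expected: relatively high score
    print(predict_similarity(third_text))   # expected: noticeably lower score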