# FacebookTransCoder.py

# Copyright (c) 2019-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# Translate sentences from the input stream.
# The model will be faster if sentences are sorted by length.
# Input sentences must use the same tokenization and BPE codes as the ones used in the model.
#
# Usage:
#     python translate.py \
#     --src_lang cpp --tgt_lang java \
#     --model_path trained_model.pth < input_code.cpp
#

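# A second invocation sketch (added for illustration, not part of the original
# header): translate Python to C++ and print the 5 best beams in order of
# decreasing likelihood. "model_1.pth" and "my_func.py" are placeholder names.
#
#     python translate.py \
#     --src_lang python --tgt_lang cpp \
#     --model_path model_1.pth --beam_size 5 < my_func.py
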
import argparse
import os
import sys

import fastBPE
import torch

import preprocessing.src.code_tokenizer as code_tokenizer
from XLM.src.data.dictionary import Dictionary, BOS_WORD, EOS_WORD, PAD_WORD, UNK_WORD, MASK_WORD
from XLM.src.model import build_model
from XLM.src.utils import AttrDict

SUPPORTED_LANGUAGES = ['cpp', 'java', 'python']


def get_parser():
    """
    Generate a parameters parser.
    """
    # parse parameters
    parser = argparse.ArgumentParser(description="Translate sentences")

    # model
    parser.add_argument("--model_path", type=str,
                        default="", help="Model path")
    parser.add_argument("--src_lang", type=str, default="",
                        help=f"Source language, should be either {', '.join(SUPPORTED_LANGUAGES[:-1])} or {SUPPORTED_LANGUAGES[-1]}")
    parser.add_argument("--tgt_lang", type=str, default="",
                        help=f"Target language, should be either {', '.join(SUPPORTED_LANGUAGES[:-1])} or {SUPPORTED_LANGUAGES[-1]}")
    parser.add_argument("--BPE_path", type=str,
                        default="data/BPE_with_comments_codes", help="Path to BPE codes.")
    parser.add_argument("--beam_size", type=int, default=1,
                        help="Beam size. The beams will be printed in order of decreasing likelihood.")

    return parser


class Translator:
    def __init__(self, params):
        # reload the checkpoint on CPU and strip the 'module.' prefix that
        # DataParallel adds to parameter names
        reloaded = torch.load(params.model_path, map_location='cpu')
        reloaded['encoder'] = {(k[len('module.'):] if k.startswith('module.') else k): v for k, v in
                               reloaded['encoder'].items()}
        assert 'decoder' in reloaded or (
            'decoder_0' in reloaded and 'decoder_1' in reloaded)
        if 'decoder' in reloaded:
            decoders_names = ['decoder']
        else:
            decoders_names = ['decoder_0', 'decoder_1']
        for decoder_name in decoders_names:
            reloaded[decoder_name] = {(k[len('module.'):] if k.startswith('module.') else k): v for k, v in
                                      reloaded[decoder_name].items()}

        self.reloaded_params = AttrDict(reloaded['params'])

        # build dictionary / update parameters
        self.dico = Dictionary(
            reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
        assert self.reloaded_params.n_words == len(self.dico)
        assert self.reloaded_params.bos_index == self.dico.index(BOS_WORD)
        assert self.reloaded_params.eos_index == self.dico.index(EOS_WORD)
        assert self.reloaded_params.pad_index == self.dico.index(PAD_WORD)
        assert self.reloaded_params.unk_index == self.dico.index(UNK_WORD)
        assert self.reloaded_params.mask_index == self.dico.index(MASK_WORD)

        # build model / reload weights
        self.reloaded_params['reload_model'] = ','.join([params.model_path] * 2)
        encoder, decoder = build_model(self.reloaded_params, self.dico)

        self.encoder = encoder[0]
        self.encoder.load_state_dict(reloaded['encoder'])
        assert len(reloaded['encoder'].keys()) == len(
            list(p for p, _ in self.encoder.state_dict().items()))

        self.decoder = decoder[0]
        self.decoder.load_state_dict(reloaded['decoder'])
        assert len(reloaded['decoder'].keys()) == len(
            list(p for p, _ in self.decoder.state_dict().items()))

        # move models to GPU and switch to evaluation mode
        self.encoder.cuda()
        self.decoder.cuda()

        self.encoder.eval()
        self.decoder.eval()
        self.bpe_model = fastBPE.fastBPE(os.path.abspath(params.BPE_path))

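    # Note added for readability (not in the original source): translate()
    # tokenizes the input with the language-specific tokenizer, applies BPE,
    # wraps the sequence in </s> tokens, maps tokens to dictionary indices,
    # runs the transformer encoder, decodes greedily (beam_size == 1) or with
    # beam search, and finally detokenizes the generated tokens back to code.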
    def translate(self, input, lang1, lang2, n=1, beam_size=1, sample_temperature=None, device='cuda:0'):
        with torch.no_grad():
            assert lang1 in {'python', 'java', 'cpp'}, lang1
            assert lang2 in {'python', 'java', 'cpp'}, lang2

            DEVICE = device
            tokenizer = getattr(code_tokenizer, f'tokenize_{lang1}')
            detokenizer = getattr(code_tokenizer, f'detokenize_{lang2}')
            lang1 += '_sa'
            lang2 += '_sa'

            lang1_id = self.reloaded_params.lang2id[lang1]
            lang2_id = self.reloaded_params.lang2id[lang2]

            tokens = [t for t in tokenizer(input)]
            tokens = self.bpe_model.apply(tokens)
            tokens = ['</s>'] + tokens + ['</s>']
            input = " ".join(tokens)
            # create batch
            len1 = len(input.split())
            len1 = torch.LongTensor(1).fill_(len1).to(DEVICE)

            x1 = torch.LongTensor([self.dico.index(w)
                                   for w in input.split()]).to(DEVICE)[:, None]
            langs1 = x1.clone().fill_(lang1_id)

            enc1 = self.encoder('fwd', x=x1, lengths=len1,
                                langs=langs1, causal=False)
            enc1 = enc1.transpose(0, 1)
            if n > 1:
                enc1 = enc1.repeat(n, 1, 1)
                len1 = len1.expand(n)

            if beam_size == 1:
                x2, len2 = self.decoder.generate(enc1, len1, lang2_id,
                                                 max_len=int(
                                                     min(self.reloaded_params.max_len, 3 * len1.max().item() + 10)),
                                                 sample_temperature=sample_temperature)
            else:
                x2, len2 = self.decoder.generate_beam(enc1, len1, lang2_id,
                                                      max_len=int(
                                                          min(self.reloaded_params.max_len, 3 * len1.max().item() + 10)),
                                                      early_stopping=False, length_penalty=1.0, beam_size=beam_size)
            tok = []
            for i in range(x2.shape[1]):
                wid = [self.dico[x2[j, i].item()] for j in range(len(x2))][1:]
                wid = wid[:wid.index(EOS_WORD)] if EOS_WORD in wid else wid
                tok.append(" ".join(wid).replace("@@ ", ""))

            results = []
            for t in tok:
                results.append(detokenizer(t))
            return results


if __name__ == '__main__':
    # generate parser / parse parameters
    parser = get_parser()
    params = parser.parse_args()

    # check parameters
    assert os.path.isfile(
        params.model_path), f"The path to the model checkpoint is incorrect: {params.model_path}"
    assert os.path.isfile(
        params.BPE_path), f"The path to the BPE tokens is incorrect: {params.BPE_path}"
    assert params.src_lang in SUPPORTED_LANGUAGES, f"The source language should be in {SUPPORTED_LANGUAGES}."
    assert params.tgt_lang in SUPPORTED_LANGUAGES, f"The target language should be in {SUPPORTED_LANGUAGES}."

    # Initialize translator
    translator = Translator(params)

    # read input code from stdin
    src_sent = []
    input = sys.stdin.read().strip()

    with torch.no_grad():
        output = translator.translate(
            input, lang1=params.src_lang, lang2=params.tgt_lang, beam_size=params.beam_size)

    for out in output:
        print("=" * 20)
        print(out)

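
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original script): driving Translator
# programmatically instead of piping code through stdin. The checkpoint name
# "model_1.pth" and the input snippet are assumptions for illustration; the
# BPE path reuses the script's default, and a CUDA-capable GPU is required
# because Translator moves both models to the GPU.
#
#     import argparse
#
#     params = argparse.Namespace(
#         model_path='model_1.pth',                   # hypothetical checkpoint
#         BPE_path='data/BPE_with_comments_codes',    # default from get_parser()
#         src_lang='python', tgt_lang='cpp', beam_size=1)
#     translator = Translator(params)
#     outputs = translator.translate(
#         "def add(a, b):\n    return a + b\n",
#         lang1='python', lang2='cpp', beam_size=1)
#     print(outputs[0])
# ---------------------------------------------------------------------------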