| """ | |
| Given a tab seperated file (.tsv) with parallel sentences, where the second column is the translation of the sentence in the first column, for example, in the format: | |
| src1 trg1 | |
| src2 trg2 | |
| ... | |
| where trg_i is the translation of src_i. | |
| Given src_i, the TranslationEvaluator checks which trg_j has the highest similarity using cosine similarity. If i == j, we assume | |
| a match, i.e., the correct translation has been found for src_i out of all possible target sentences. | |
| It then computes an accuracy over all possible source sentences src_i. Equivalently, it computes also the accuracy for the other direction. | |
| A high accuracy score indicates that the model is able to find the correct translation out of a large pool with sentences. | |
| Usage: | |
| python [model_name_or_path] [parallel-file1] [parallel-file2] ... | |
| For example: | |
| python distiluse-base-multilingual-cased TED2020-en-de.tsv.gz | |
| See the training_multilingual/get_parallel_data_...py scripts for getting parallel sentence data from different sources | |
| """ | |
from sentence_transformers import SentenceTransformer, evaluation, LoggingHandler
import sys
import gzip
import os
import logging

logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)
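
# First CLI argument: model name or path; remaining arguments: parallel TSV files to evaluate.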
model_name = sys.argv[1]
filepaths = sys.argv[2:]
inference_batch_size = 32
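
# Optional guard (not part of the original script): print the usage notes from the
# module docstring and exit if no model or parallel files were given.
if len(sys.argv) < 3:
    print(__doc__)
    sys.exit(1)

# Load the model once; it is reused to evaluate every file.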
model = SentenceTransformer(model_name)

for filepath in filepaths:
    src_sentences = []
    trg_sentences = []
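    # Read the (optionally gzipped) tab-separated file; the first column is the source
    # sentence, the second column its translation. Extra columns are ignored.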
    with gzip.open(filepath, 'rt', encoding='utf8') if filepath.endswith('.gz') else open(filepath, 'r', encoding='utf8') as fIn:
        for line in fIn:
            splits = line.strip().split('\t')
            if len(splits) >= 2:
                src_sentences.append(splits[0])
                trg_sentences.append(splits[1])
    logger.info(os.path.basename(filepath) + ": " + str(len(src_sentences)) + " sentence pairs")
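    # The TranslationEvaluator encodes both sentence lists and reports how often the correct
    # translation is the most similar candidate (by cosine similarity), in both directions.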
    dev_trans_acc = evaluation.TranslationEvaluator(src_sentences, trg_sentences, name=os.path.basename(filepath), batch_size=inference_batch_size)
    dev_trans_acc(model)