Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package il.ac.tau.cs.sw1.ex5;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
/**
 * A bigram language model built from a plain-text corpus.
 *
 * <p>Each line of the corpus is split on single spaces into tokens. A token is legal when it
 * contains at least one letter (it is stored lowercased) or when it is a non-empty run of digits
 * (it is stored as {@link #SOME_NUM}). The vocabulary holds, in order of first appearance, at most
 * {@link #MAX_VOCABULARY_SIZE} distinct legal words. {@code mBigramCounts[i][j]} is the number of
 * times vocabulary word {@code j} appears directly after vocabulary word {@code i} on the same
 * line.
 */
public class BigramModel {
    public static final int MAX_VOCABULARY_SIZE = 14500;
    public static final String VOC_FILE_SUFFIX = ".voc";
    public static final String COUNTS_FILE_SUFFIX = ".counts";
    public static final String SOME_NUM = "some_num";
    public static final int ELEMENT_NOT_FOUND = -1;

    /** Vocabulary words; a word's array index is its vocabulary index. */
    String[] mVocabulary;
    /** {@code mBigramCounts[i][j]} = number of times word j directly followed word i. */
    int[][] mBigramCounts;

    // DO NOT CHANGE THIS !!!
    public void initModel(String fileName) throws IOException {
        mVocabulary = buildVocabularyIndex(fileName);
        mBigramCounts = buildCountsArray(fileName, mVocabulary);
    }

    /** Returns whether {@code word} contains at least one letter. */
    public static boolean wordWithLetter(String word) {
        for (char ch : word.toCharArray()) {
            if (Character.isLetter(ch)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns whether {@code word} is a non-empty sequence of digits.
     *
     * <p>The empty string is NOT a number: empty tokens appear for empty lines and consecutive
     * spaces and must not be counted as {@link #SOME_NUM}.
     */
    public static boolean isNumber(String word) {
        if (word.isEmpty()) {
            return false;
        }
        for (char ch : word.toCharArray()) {
            if (!Character.isDigit(ch)) {
                return false;
            }
        }
        return true;
    }

    /** Returns whether {@code word} contains a letter or is a number. */
    public static boolean isLegalWord(String word) {
        return wordWithLetter(word) || isNumber(word);
    }

    /**
     * Maps a raw corpus token to its vocabulary form.
     *
     * @return the lowercased token if it contains a letter, {@link #SOME_NUM} if it is a number,
     *     or {@code null} if the token is not a legal word
     */
    private static String normalizeToken(String token) {
        if (wordWithLetter(token)) {
            // Locale.ROOT: casing must not depend on the JVM's default locale (e.g. Turkish 'I').
            return token.toLowerCase(Locale.ROOT);
        }
        if (isNumber(token)) {
            return SOME_NUM;
        }
        return null;
    }

    /**
     * Linear search for {@code word} in {@code vocabulary}.
     *
     * @return the lowest index holding {@code word}, or {@link #ELEMENT_NOT_FOUND} if absent
     *     (including when {@code vocabulary} is empty)
     */
    public static int getIndex(String[] vocabulary, String word) {
        for (int i = 0; i < vocabulary.length; i++) {
            if (vocabulary[i] != null && vocabulary[i].equals(word)) {
                return i;
            }
        }
        return ELEMENT_NOT_FOUND;
    }

    /**
     * Replaces every occurrence of {@code word} in {@code arr} after the first one with "#".
     * Kept for backward compatibility; this class no longer uses it.
     */
    public static void deleteDuplicates(String word, String[] arr) {
        int count = 0;
        for (int i = 0; i < arr.length; i++) {
            if (arr[i].equals(word)) {
                count += 1;
                if (count > 1) {
                    arr[i] = "#"; // denote deletion
                }
            }
        }
    }

    /**
     * Builds the vocabulary: the first {@link #MAX_VOCABULARY_SIZE} distinct legal words of the
     * file, in order of first appearance.
     *
     * @post: mVocabulary = prev(mVocabulary)
     * @post: mBigramCounts = prev(mBigramCounts)
     */
    public String[] buildVocabularyIndex(String fileName) throws IOException { // Q 1
        // LinkedHashSet removes duplicates in O(1) while preserving first-appearance order.
        Set<String> vocabulary = new LinkedHashSet<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(fileName))) {
            String line;
            while ((line = reader.readLine()) != null) {
                for (String token : line.split(" ")) {
                    String normalized = normalizeToken(token);
                    if (normalized != null && vocabulary.size() < MAX_VOCABULARY_SIZE) {
                        vocabulary.add(normalized);
                    }
                }
            }
        }
        return vocabulary.toArray(new String[0]);
    }

    /**
     * Counts, for every pair of adjacent tokens on the same line whose vocabulary forms are both
     * in {@code vocabulary}, how often the second follows the first. Tokens outside the
     * vocabulary (illegal, or dropped by truncation) break the chain and are skipped.
     *
     * @post: mVocabulary = prev(mVocabulary)
     * @post: mBigramCounts = prev(mBigramCounts)
     */
    public int[][] buildCountsArray(String fileName, String[] vocabulary) throws IOException { // Q - 2
        Map<String, Integer> indexOf = new HashMap<>();
        for (int i = 0; i < vocabulary.length; i++) {
            indexOf.putIfAbsent(vocabulary[i], i); // first index wins, like getIndex
        }
        int[][] counts = new int[vocabulary.length][vocabulary.length];
        try (BufferedReader reader = new BufferedReader(new FileReader(fileName))) {
            String line;
            while ((line = reader.readLine()) != null) {
                Integer previous = null; // bigrams never cross line boundaries
                for (String token : line.split(" ")) {
                    String normalized = normalizeToken(token);
                    Integer current = (normalized == null) ? null : indexOf.get(normalized);
                    if (previous != null && current != null) {
                        counts[previous][current] += 1;
                    }
                    previous = current;
                }
            }
        }
        return counts;
    }

    /**
     * Writes the model to {@code fileName + ".voc"} and {@code fileName + ".counts"},
     * overwriting any existing files.
     *
     * <p>Vocabulary file: a header {@code "<n> words"} followed by {@code "<index>,<word>"} lines.
     * Counts file: one {@code "<i>,<j>:<count>"} line per non-zero bigram count.
     *
     * @pre: the method initModel was called (the language model is initialized)
     * @pre: fileName is a legal file path
     * @throws IOException if either file cannot be opened or written
     */
    public void saveModel(String fileName) throws IOException { // Q-3
        String vocabFile = fileName + VOC_FILE_SUFFIX;
        try (PrintWriter writer = new PrintWriter(new FileWriter(vocabFile))) {
            writer.println(mVocabulary.length + " words");
            for (int i = 0; i < mVocabulary.length; i++) {
                writer.println(i + "," + mVocabulary[i]);
            }
            failOnWriteError(writer, vocabFile);
        }
        String countsFile = fileName + COUNTS_FILE_SUFFIX;
        try (PrintWriter writer = new PrintWriter(new FileWriter(countsFile))) {
            for (int i = 0; i < mBigramCounts.length; i++) {
                for (int j = 0; j < mBigramCounts[i].length; j++) {
                    if (mBigramCounts[i][j] > 0) {
                        writer.println(i + "," + j + ":" + mBigramCounts[i][j]);
                    }
                }
            }
            failOnWriteError(writer, countsFile);
        }
    }

    /** Surfaces errors that PrintWriter swallows instead of throwing. */
    private static void failOnWriteError(PrintWriter writer, String path) throws IOException {
        if (writer.checkError()) {
            throw new IOException("Failed writing " + path);
        }
    }

    /**
     * Loads a model previously written by {@link #saveModel(String)}. The fields are replaced
     * only after both files were read successfully.
     *
     * @pre: fileName is a legal file path
     * @throws IOException if a file cannot be read, is empty, or has a malformed line
     */
    public void loadModel(String fileName) throws IOException { // Q - 4
        String vocabFile = fileName + VOC_FILE_SUFFIX;
        String[] vocabulary;
        try (BufferedReader reader = new BufferedReader(new FileReader(vocabFile))) {
            String header = reader.readLine();
            if (header == null) {
                throw new IOException("Empty vocabulary file: " + vocabFile);
            }
            vocabulary = new String[Integer.parseInt(header.split(" ")[0])];
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.isEmpty()) {
                    continue;
                }
                // Split on the FIRST comma only: legal words may themselves contain commas.
                int comma = line.indexOf(',');
                if (comma < 0) {
                    throw new IOException("Malformed line in " + vocabFile + ": " + line);
                }
                vocabulary[Integer.parseInt(line.substring(0, comma))] = line.substring(comma + 1);
            }
        }

        String countsFile = fileName + COUNTS_FILE_SUFFIX;
        int[][] counts = new int[vocabulary.length][vocabulary.length];
        try (BufferedReader reader = new BufferedReader(new FileReader(countsFile))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.isEmpty()) {
                    continue;
                }
                int comma = line.indexOf(',');
                int colon = line.indexOf(':');
                if (comma < 0 || colon < comma) {
                    throw new IOException("Malformed line in " + countsFile + ": " + line);
                }
                int i = Integer.parseInt(line.substring(0, comma));
                int j = Integer.parseInt(line.substring(comma + 1, colon));
                counts[i][j] = Integer.parseInt(line.substring(colon + 1));
            }
        }
        mVocabulary = vocabulary;
        mBigramCounts = counts;
    }

    /*
     * @pre: word is in lowercase
     * @pre: the method initModel was called (the language model is initialized)
     * @post: $ret = -1 if word is not in vocabulary, otherwise $ret = the index of word in vocabulary
     */
    public int getWordIndex(String word) { // Q - 5
        return getIndex(mVocabulary, word);
    }

    /*
     * @pre: word1, word2 are in lowercase
     * @pre: the method initModel was called (the language model is initialized)
     * @post: $ret = the count for the bigram <word1, word2>. if one of the words does not
     * exist in the vocabulary, $ret = 0
     */
    public int getBigramCount(String word1, String word2) { // Q - 6
        int first = getWordIndex(word1);
        int second = getWordIndex(word2);
        if (first == ELEMENT_NOT_FOUND || second == ELEMENT_NOT_FOUND) {
            return 0;
        }
        return mBigramCounts[first][second];
    }

    /*
     * @pre word in lowercase
     * @pre: the method initModel was called (the language model is initialized)
     * @post $ret = the word with the lowest vocabulary index that appears most frequently after
     * word. If no bigram starting with word was seen, or word is not in the vocabulary,
     * $ret = null
     */
    public String getMostFrequentProceeding(String word) { // Q - 7
        int i = getWordIndex(word);
        if (i == ELEMENT_NOT_FOUND) {
            return null;
        }
        int[] row = mBigramCounts[i];
        int bestCount = 0;
        String best = null;
        for (int j = 0; j < row.length; j++) {
            if (row[j] > bestCount) { // strict '>' keeps the lowest index on ties
                bestCount = row[j];
                best = mVocabulary[j];
            }
        }
        return best;
    }

    /* @pre: sentence is in lowercase
     * @pre: the method initModel was called (the language model is initialized)
     * @pre: each two words in the sentence are separated with a single space
     * @post: if sentence is probable according to the model (empty, or every word is in the
     * vocabulary and every adjacent pair was seen as a bigram), $ret = true, else $ret = false
     */
    public boolean isLegalSentence(String sentence) { // Q - 8
        if (sentence.isEmpty()) {
            // "".split(" ") yields [""], so the empty sentence must be handled explicitly.
            return true;
        }
        String[] words = sentence.split(" ");
        int previous = ELEMENT_NOT_FOUND;
        for (int k = 0; k < words.length; k++) {
            int current = getWordIndex(words[k]);
            if (current == ELEMENT_NOT_FOUND) {
                return false;
            }
            if (k > 0 && mBigramCounts[previous][current] == 0) {
                return false;
            }
            previous = current;
        }
        return true;
    }

    /** Returns whether every element of {@code arr} is zero (true for an empty array). */
    public static boolean onlyZeros(int[] arr) {
        for (int num : arr) {
            if (num != 0) {
                return false;
            }
        }
        return true;
    }

    /*
     * @pre: arr1.length = arr2.length
     * @post: if arr1 or arr2 is filled only with zeros, $ret = -1, otherwise $ret = the cosine
     * similarity of the two vectors
     */
    public static double calcCosineSim(int[] arr1, int[] arr2) { // Q - 9
        if (onlyZeros(arr1) || onlyZeros(arr2)) {
            return -1.;
        }
        // Products are taken in double: int * int overflows for counts above ~46340.
        double numerator = 0;
        for (int i = 0; i < arr1.length; i++) {
            numerator += (double) arr1[i] * arr2[i];
        }
        double sumSquares1 = 0;
        for (int number : arr1) {
            sumSquares1 += (double) number * number;
        }
        double sumSquares2 = 0;
        for (int number : arr2) {
            sumSquares2 += (double) number * number;
        }
        return numerator / (Math.sqrt(sumSquares1) * Math.sqrt(sumSquares2));
    }

    /*
     * @pre: word is in vocabulary
     * @pre: the method initModel was called (the language model is initialized)
     * @post: $ret = w implies that w is the word (lowest index on ties) with the largest
     * cosineSimilarity(vector for word, vector for w) among all the other words in vocabulary.
     * If the vocabulary holds only word, $ret = word; if word is not in vocabulary, $ret = null
     */
    public String getClosestWord(String word) { // Q - 10
        int wordIndex = getWordIndex(word);
        if (wordIndex == ELEMENT_NOT_FOUND) {
            return null;
        }
        if (mVocabulary.length == 1) {
            return mVocabulary[0];
        }
        int[] wordVector = mBigramCounts[wordIndex];
        String closest = null;
        double bestSimilarity = Double.NEGATIVE_INFINITY;
        for (int j = 0; j < mVocabulary.length; j++) {
            if (j == wordIndex) {
                continue;
            }
            double similarity = calcCosineSim(wordVector, mBigramCounts[j]);
            if (similarity > bestSimilarity) { // strict '>' keeps the lowest index on ties
                bestSimilarity = similarity;
                closest = mVocabulary[j];
            }
        }
        return closest;
    }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement