source: git/machine_learning/ml_python/predictor_runner.py @ 0696a3

spielwiese
Last change on this file since 0696a3 was 0696a3, checked in by Murray Heymann <heymann.murray@…>, 5 years ago
Implement C prediction function
  • Property mode set to 100644
File size: 2.5 KB
Line 
1"""
2A script to demonstrate how the predictor works
3"""
4
5import os
6import sys
7import time
8import numpy as np
9
10from model.predictor import HelpPagePredictor
11from common.keyword_vector import read_dictionary, count_occurances
12from common.lookuptable import create_table
13from common.constants import KEYWORDS_FILE
14
def find_prediction(filename):
    """
    Given a file name as string, get the predicted help page name

    Reads the keyword dictionary from KEYWORDS_FILE, calls create_table
    with it and hands the resulting vectors and file list to
    get_prediction.  The duration of the create_table call and of the
    get_prediction call are each printed to standard output.

    filename: passed unchanged to get_prediction.
    Returns the value get_prediction returns for filename.
    """
    dictionary = read_dictionary(KEYWORDS_FILE)

    # perf_counter is meant for measuring short durations; time.time follows
    # the wall clock, which may be adjusted between the two readings.
    start = time.perf_counter()
    vectors, file_list = create_table(dictionary=dictionary)
    end = time.perf_counter()
    print(end - start, "seconds to create_table")

    start = time.perf_counter()
    pred = get_prediction(filename, dictionary, vectors, file_list)
    end = time.perf_counter()
    print(end - start, "seconds to make prediction.")
    return pred
31
def get_prediction(filename, dictionary, vectors, file_list):
    """
    Fit a fresh predictor and return its prediction for one file

    A new HelpPagePredictor is fitted on (vectors, file_list) on every
    call.  The vector for filename is obtained from count_occurances
    using dictionary, wrapped into a one-row array, and the first entry
    of the predictor's output is returned.
    """
    model = HelpPagePredictor()
    model.fit(vectors, file_list)

    # predict is given a batch, so wrap the single vector in an outer list.
    occurrence_vec = count_occurances(filename, dictionary)
    batch = np.array([occurrence_vec])
    results = model.predict(batch)
    return results[0]
42
43
def main():
    """
    Run some basic tests

    Reads the keyword dictionary, calls create_table, fits a
    HelpPagePredictor and prints, in order:
      * the prediction for an all-zero vector (with timing),
      * the prediction for "extract.lib",
      * the prediction for every command-line argument that names an
        existing file.  Arguments that are not files are skipped
        silently.
    """
    print("Running some tests")

    dictionary = read_dictionary(KEYWORDS_FILE)

    # perf_counter is meant for measuring short durations; time.time follows
    # the wall clock, which may be adjusted between the two readings.
    start = time.perf_counter()
    vectors, file_list = create_table(dictionary=dictionary)
    end = time.perf_counter()
    print(end - start, "seconds to create_table")

    predictor = HelpPagePredictor()
    predictor.fit(vectors, file_list)

    print("prediction for zero vector")
    start = time.perf_counter()
    zerovec = np.zeros(len(dictionary))
    prediction = predictor.predict(np.array([zerovec]))
    end = time.perf_counter()
    print(end - start, "seconds to make prediction")
    print(prediction)
    print()

    prediction = get_prediction("extract.lib",
                                dictionary,
                                vectors,
                                file_list)
    print(prediction)
    print()

    # sys.argv[0] is the script itself; only the remaining arguments are
    # candidate input files.  An empty slice simply skips the loop.
    for path in sys.argv[1:]:
        if not os.path.isfile(path):
            continue

        print("predicting for file", path)
        prediction = get_prediction(path,
                                    dictionary,
                                    vectors,
                                    file_list)
        print(prediction)
        print()
# Only run the demo when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.