source: git/machine_learning/model/predictor.py @ f59883

spielwiese
Last change on this file since f59883 was f59883, checked in by Murray Heymann <heymann.murray@…>, 5 years ago
Expand testing
  • Property mode set to 100644
File size: 4.2 KB
Line 
1"""
2Define the predictor class for classifying according to help page.
3"""
4
5# import cProfile
6import os
7import sys
8import time
9
10# Third party imports
11import numpy as np
12from sklearn.base import BaseEstimator, ClassifierMixin
13
14# Local imports
15from common.keyword_vector import vector_distance, count_occurances, \
16        read_dictionary, normalise_vector
17from common.lookuptable import create_table
18from common.constants import KEYWORDS_FILE
19
20
class HelpPagePredictor(BaseEstimator, ClassifierMixin):
    """
    Classifier to assign the given input to a Singular helppage.

    The model is a nearest-neighbour lookup: ``fit`` stores one reference
    vector per help page, and ``predict`` returns, for every query vector,
    the page whose reference vector has the largest dot product with it.
    """
    def __init__(self):
        """
        Define attributes
        """
        # Reference vectors, one per help page; set by fit().
        self.vectors = None
        # Help-page names; entry i belongs to reference vector i.
        self.files = None


    def fit(self, X, y): # pylint: disable=invalid-name
        """
        Setup the correspondence of vectors to help-files

        X holds the reference vectors and y the help-file name for each of
        them. Both are stored as given, without copying or checking.

        Returns this predictor so that calls can be chained.
        Raises AssertionError if X or y is None.
        """
        # Raised explicitly instead of via `assert`, so the validation is
        # not stripped when Python runs with -O. Exception type and message
        # are the same as before.
        if X is None:
            raise AssertionError("Please provide data for X")
        if y is None:
            raise AssertionError("Please provide data for y")
        self.vectors = X
        self.files = y
        return self


    def predict(self, X): # pylint: disable=invalid-name
        """
        Classify the input vectors

        X is a sequence of query vectors. The result is a numpy array with
        one entry of ``self.files`` per query, chosen as the reference
        vector with the largest dot product (the first one on ties).

        Raises AssertionError if X is None, RuntimeError if fit() has not
        been called, and ValueError if X is not a sequence of vectors.
        """
        if X is None:
            raise AssertionError("Please provide data for X")

        queries = np.asarray(X)
        if len(queries) == 0:
            # Nothing to classify.
            return np.array([])
        if self.vectors is None or self.files is None:
            raise RuntimeError(
                "HelpPagePredictor.predict() called before fit()")
        if queries.ndim != 2:
            raise ValueError("X must be a sequence of vectors")

        references = np.asarray(self.vectors)
        if len(references) == 0:
            # No reference vector can match: same fallback as below.
            best = np.full(len(queries), -1)
        else:
            # scores[q, r] is the dot product of query q with reference r.
            # Using the dot product instead of vector_distance is much
            # faster.
            scores = np.dot(queries, references.T)
            # A NaN score never wins the comparison.
            scores = np.where(np.isnan(scores), -np.inf, scores)
            best = np.argmax(scores, axis=1)
            # A query without any usable score falls back to index -1,
            # i.e. the last entry of self.files.
            rows = np.arange(len(best))
            best[np.isneginf(scores[rows, best])] = -1

        # find corresponding filenames
        return np.array([self.files[index] for index in best])
68
69
def basic_vector_tests():
    """
    Some basic sanity tests

    Builds three reference vectors and one test vector, prints them along
    with the vector_distance from the test vector to each reference, and
    finally prints what HelpPagePredictor predicts for the test vector.
    Nothing is checked automatically; the output is meant to be read.
    """
    predictor = HelpPagePredictor()

    raw_references = ([1, 4, 10], [2, 3, 1], [3, 9, 3])
    vector1, vector2, vector3 = (
        normalise_vector(np.array(raw)) for raw in raw_references
    )

    vectors = np.array([vector1, vector2, vector3])
    files = np.array(["file1", "file2", "file3"])
    for table in (vectors, files):
        print(table)
    print()

    testvec = normalise_vector(np.array([1, 1, 1]))
    print("test vector:")
    print(testvec)
    print()

    # One heading per reference vector, in the order they were defined.
    labelled_references = (
        ("distance to 1", vector1),
        ("distance to 2", vector2),
        ("distance to 3", vector3),
    )
    for heading, reference in labelled_references:
        print(heading)
        print(vector_distance(testvec, reference))
        print()

    predictor.fit(vectors, files)
    prediction = predictor.predict(np.array([testvec]))
    print("Prediction:")
    print(prediction)
    print()
105
106
def _report_prediction(predictor, build_vector):
    """
    Build one query vector, classify it and print the timing and result.

    build_vector is called without arguments inside the timed section, so
    the reported duration covers creating the vector as well as predicting.
    """
    start = time.perf_counter()
    prediction = predictor.predict(np.array([build_vector()]))
    elapsed = time.perf_counter() - start
    print(elapsed, "seconds to make prediction")
    print(prediction)
    print()


def main():
    """
    Run some basic tests

    Runs basic_vector_tests, builds the lookup table from the keyword
    dictionary, and prints timed predictions for "extract.lib", for a zero
    vector and for every command line argument that is an existing file.
    """
    print("Running some tests")

    basic_vector_tests()

    dictionary = read_dictionary(KEYWORDS_FILE)

    # perf_counter is used for all intervals: unlike time.time it is not
    # affected by adjustments of the system clock.
    start = time.perf_counter()
    vectors, file_list = create_table(dictionary=dictionary)
    print(time.perf_counter() - start, "seconds to create_table")

    predictor = HelpPagePredictor()
    predictor.fit(vectors, file_list)

    _report_prediction(
        predictor, lambda: count_occurances("extract.lib", dictionary))

    print("prediction for zero vector")
    # NOTE(review): the "- 2" presumably matches the vector length that
    # count_occurances produces for this dictionary -- confirm.
    _report_prediction(predictor, lambda: np.zeros(len(dictionary) - 2))

    # Arguments that are not existing files are skipped silently.
    for path in sys.argv[1:]:
        if not os.path.isfile(path):
            continue
        print("predicting for file", path)
        _report_prediction(
            predictor,
            lambda path=path: count_occurances(path, dictionary))
156
157
if __name__ == "__main__":
    # To profile a run, import cProfile and call cProfile.run("main()")
    # here instead of calling main() directly.
    main()
Note: See TracBrowser for help on using the repository browser.