1 | """ |
---|
2 | Define the predictor class for classifying according to help page. |
---|
3 | """ |
---|
4 | |
---|
5 | # import cProfile |
---|
6 | import os |
---|
7 | import sys |
---|
8 | import time |
---|
9 | |
---|
10 | # Third party imports |
---|
11 | import numpy as np |
---|
12 | from sklearn.base import BaseEstimator, ClassifierMixin |
---|
13 | |
---|
14 | # Local imports |
---|
15 | from common.keyword_vector import vector_distance, count_occurances, \ |
---|
16 | read_dictionary, normalise_vector |
---|
17 | from common.lookuptable import create_table |
---|
18 | from common.constants import KEYWORDS_FILE |
---|
19 | |
---|
20 | |
---|
class HelpPagePredictor(BaseEstimator, ClassifierMixin):
    """
    Classifier to assign the given input to a Singular help page.

    This is a nearest-neighbour lookup. ``fit`` stores one reference vector
    per help file. ``predict`` returns, for each input vector, the help file
    whose reference vector has the largest dot product with that input.
    """

    def __init__(self):
        """
        Define attributes; both are populated by ``fit``.
        """
        # Reference vectors, one per help file (set by fit).
        self.vectors = None
        # Help-file names; self.files[i] belongs to self.vectors[i].
        self.files = None

    def fit(self, X, y):  # pylint: disable=invalid-name
        """
        Set up the correspondence of vectors to help files.

        :param X: sequence/array of reference vectors, one per help file.
        :param y: sequence of help-file names aligned index-by-index with X.
        :return: self, so calls can be chained scikit-learn style.
        """
        assert X is not None, "Please provide data for X"
        assert y is not None, "Please provide data for y"
        self.vectors = X
        self.files = y
        return self

    def predict(self, X):  # pylint: disable=invalid-name
        """
        Classify the input vectors.

        For every vector in X, pick the reference vector with the largest dot
        product and return the help file associated with it. If several
        reference vectors tie, the first one wins. NaN scores are ignored.
        NOTE(review): the dot product acts as a similarity measure only if
        the reference and query vectors are comparably normalised — confirm
        against how create_table builds them.

        :param X: 2-D array-like of query vectors, each the same length as
            the fitted reference vectors.
        :return: numpy array with one help-file name per query vector
            (an empty array when X is empty).
        """
        assert X is not None, "Please provide data for X"
        assert self.vectors is not None and self.files is not None, \
            "Please call fit before predict"

        queries = np.asarray(X)
        if len(queries) == 0:
            return np.array([])

        # One matrix product scores every query against every reference
        # vector: scores[i, j] == dot(X[i], vectors[j]).
        scores = np.dot(queries, np.asarray(self.vectors).T)
        # Treat NaN scores as "never best", as the element-wise comparison did.
        scores = np.where(np.isnan(scores), -np.inf, scores)
        # argmax returns the first maximal index, preserving the tie-break.
        best = np.argmax(scores, axis=1)

        return np.array([self.files[index] for index in best])
---|
68 | |
---|
69 | |
---|
def basic_vector_tests():
    """
    Print a small hand-made sanity check of HelpPagePredictor.

    Three normalised reference vectors are labelled "file1".."file3". A
    normalised all-ones test vector is compared against each of them with
    vector_distance. The label the predictor assigns to it is then printed.
    """
    predictor = HelpPagePredictor()
    raw_references = ([1, 4, 10], [2, 3, 1], [3, 9, 3])
    references = [normalise_vector(np.array(raw)) for raw in raw_references]

    vectors = np.array(references)
    files = np.array(["file1", "file2", "file3"])
    for item in (vectors, files):
        print(item)
    print()

    testvec = normalise_vector(np.array([1, 1, 1]))
    print("test vector:")
    print(testvec)
    print()

    # Distances are shown only for manual inspection; predict() itself
    # ranks by dot product.
    for number, reference in enumerate(references, start=1):
        print("distance to", number)
        print(vector_distance(testvec, reference))
        print()

    prediction = predictor.fit(vectors, files).predict(np.array([testvec]))
    print("Prediction:")
    print(prediction)
    print()
---|
105 | |
---|
106 | |
---|
def _timed_predict(predictor, vector):
    """
    Predict the help page for one vector, printing timing and result.

    :param predictor: a fitted HelpPagePredictor.
    :param vector: a single feature vector of the fitted vector length.
    :return: the array returned by ``predictor.predict``.
    """
    start = time.perf_counter()
    prediction = predictor.predict(np.array([vector]))
    elapsed = time.perf_counter() - start
    print(elapsed, "seconds to make prediction")
    print(prediction)
    print()
    return prediction


def main():
    """
    Run some basic tests.

    Steps:
    - Run the hand-made sanity check.
    - Build the lookup table from the keyword dictionary.
    - Print a prediction, with its timing, for "extract.lib".
    - Do the same for an all-zero vector.
    - Do the same for every existing file named on the command line.
    """
    print("Running some tests")

    basic_vector_tests()

    dictionary = read_dictionary(KEYWORDS_FILE)

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    vectors, file_list = create_table(dictionary=dictionary)
    print(time.perf_counter() - start, "seconds to create_table")

    test_vec = count_occurances("extract.lib", dictionary)
    predictor = HelpPagePredictor()
    predictor.fit(vectors, file_list)
    _timed_predict(predictor, test_vec)

    print("prediction for zero vector")
    # Size the zero vector from the fitted table, not from the dictionary,
    # so its length always matches what predict() takes dot products with.
    zerovec = np.zeros(np.shape(vectors)[1])
    _timed_predict(predictor, zerovec)

    for path in sys.argv[1:]:
        if not os.path.isfile(path):
            print("skipping", path, "(not a file)", file=sys.stderr)
            continue
        print("predicting for file", path)
        _timed_predict(predictor, count_occurances(path, dictionary))
---|
156 | |
---|
157 | |
---|
if __name__ == '__main__':
    # To profile, run the script under cProfile instead of editing code:
    #   python -m cProfile -s cumtime <this file> [files...]
    main()
---|