1 | """ |
---|
2 | Define the predictor class for classifying according to help page. |
---|
3 | """ |
---|
4 | |
---|
5 | # Third party imports |
---|
6 | import numpy as np |
---|
7 | from sklearn.base import BaseEstimator, ClassifierMixin |
---|
8 | |
---|
9 | # Local imports |
---|
10 | from common.keyword_vector import vector_distance, count_occurances, \ |
---|
11 | read_dictionary |
---|
12 | from common.lookuptable import create_table |
---|
13 | from common.constants import KEYWORDS_FILE |
---|
14 | |
---|
15 | |
---|
class HelpPagePredictor(BaseEstimator, ClassifierMixin):
    """
    Classifier to assign the given input to a Singular helppage.

    Works as a 1-nearest-neighbour classifier: ``fit`` stores reference
    vectors together with the help-page file each one belongs to, and
    ``predict`` returns, for every input vector, the file of the stored
    vector with the smallest ``vector_distance``.
    """
    def __init__(self):
        """
        Define attributes
        """
        # Reference vectors; set by fit().
        self.vectors = None
        # Help-page files; self.files[i] is the label of self.vectors[i].
        self.files = None

    def fit(self, X, y):  # pylint: disable=invalid-name
        """
        Setup the correspondence of vectors to help-files.

        :param X: sequence of reference vectors
        :param y: sequence of help-page files, ``y[i]`` belonging to ``X[i]``
            (NOTE(review): equal lengths are assumed, not checked)
        :return: self, so calls can be chained
        """
        assert X is not None, "Please provide data for X"
        assert y is not None, "Please provide data for y"
        self.vectors = X
        self.files = y
        return self

    def _closest_index(self, x):  # pylint: disable=invalid-name
        """
        Return the position of the stored vector nearest to ``x``.

        :raises ValueError: if no stored vector has a distance below
            infinity (e.g. fit() was given an empty X)
        """
        min_val = float("inf")
        min_index = None
        for index, vec in enumerate(self.vectors):
            dist = vector_distance(x, vec)
            # Strict "<" keeps the first of several equally close vectors.
            if dist < min_val:
                min_val = dist
                min_index = index
        if min_index is None:
            raise ValueError(
                "No reference vector closer than infinity; "
                "was fit() called with non-empty data?")
        return min_index

    def predict(self, X):  # pylint: disable=invalid-name
        """
        Classify the input vectors.

        :param X: iterable of vectors to classify
        :return: numpy array holding, for each input vector, the help-page
            file of its nearest stored vector
        :raises ValueError: if no stored vector is closer than infinity
        """
        assert X is not None, "Please provide data for X"
        # The index is tracked during the scan rather than recovered with
        # list(self.vectors).index(...): that avoided a second O(n) pass and
        # breaks for numpy rows, whose == comparison is elementwise.
        ret_list = [self.files[self._closest_index(x)]
                    for x in X]  # pylint: disable=invalid-name
        return np.array(ret_list)
---|
61 | |
---|
62 | |
---|
def main():
    """
    Smoke-test the predictor by printing intermediate results.

    First fits it on three hand-written keyword vectors, printing the
    distance from a probe vector to each of them and the resulting
    prediction. Then refits it on the table built from KEYWORDS_FILE and
    prints the prediction for the vector that count_occurances() produces
    for "extract.lib".
    """
    print("Running some tests")
    predictor = HelpPagePredictor()

    vector1 = {"hello": 1, "bye": 4, "pizza": 10}
    vector2 = {"hello": 2, "bye": 3, "pizza": 1}
    vector3 = {"hello": 3, "bye": 9, "pizza": 3}
    toy_vectors = np.array([vector1, vector2, vector3])
    toy_files = np.array(["file1", "file2", "file3"])
    print(toy_vectors)
    print(toy_files)

    probe = {"hello": 1, "bye": 1, "pizza": 1}
    labelled_references = (
        ("distance to 1", vector1),
        ("distance to 2", vector2),
        ("distance to 3", vector3),
    )
    for label, reference in labelled_references:
        print(label)
        print(vector_distance(probe, reference))
        print()

    # fit() returns the predictor itself, so predict() can be chained on it.
    print(predictor.fit(toy_vectors, toy_files).predict(np.array([probe])))

    # Same predictor, now fitted on the real keyword table.
    dictionary = read_dictionary(KEYWORDS_FILE)
    table_vectors, file_list = create_table(dictionary=dictionary)
    library_probe = count_occurances("extract.lib", dictionary)
    predictor.fit(table_vectors, file_list)
    print(predictor.predict(np.array([library_probe])))
---|
100 | |
---|
if __name__ == '__main__':
    # Only run the smoke tests when executed as a script, not on import.
    main()
---|