source: git/machine_learning/model/predictor.py @ 872089

spielwiese
Last change on this file since 872089 was 872089, checked in by Murray Heymann <heymann.murray@…>, 5 years ago
Reorganize source tree, make test prediction
  • Property mode set to 100644
File size: 2.8 KB
Line 
1"""
2Define the predictor class for classifying according to help page.
3"""
4
5# Third party imports
6import numpy as np
7from sklearn.base import BaseEstimator, ClassifierMixin
8
9# Local imports
10from common.keyword_vector import vector_distance, count_occurances, \
11                                    read_dictionary
12from common.lookuptable import create_table
13from common.constants import KEYWORDS_FILE
14
15
class HelpPagePredictor(BaseEstimator, ClassifierMixin):
    """
    Classifier to assign the given input to a Singular help page.

    ``fit`` stores reference vectors together with the help-file label of
    each one; ``predict`` returns, for every input vector, the label of the
    stored vector with the smallest ``vector_distance`` to it (a
    nearest-neighbour lookup).
    """
    def __init__(self):
        """
        Define attributes; both are populated by ``fit``.
        """
        # Reference vectors passed to fit().
        self.vectors = None
        # Labels (help-file names); self.files[i] belongs to self.vectors[i].
        self.files = None

    def fit(self, X, y):  # pylint: disable=invalid-name
        """
        Set up the correspondence of vectors to help files.

        :param X: sequence of reference vectors.
        :param y: sequence of labels, ``y[i]`` being the help file for
            ``X[i]``.
        :return: ``self``, so that calls can be chained.
        :raises AssertionError: if ``X`` or ``y`` is ``None`` (these checks
            are ``assert`` statements and are skipped under ``python -O``).
        :raises ValueError: if ``X`` and ``y`` have different lengths.
        """
        assert X is not None, "Please provide data for X"
        assert y is not None, "Please provide data for y"
        # Catch mismatched inputs here instead of as an IndexError (or a
        # silently wrong label) later in predict().
        if len(X) != len(y):
            raise ValueError(
                "X and y must have the same length, got %d and %d"
                % (len(X), len(y)))
        self.vectors = X
        self.files = y
        return self

    def predict(self, X):  # pylint: disable=invalid-name
        """
        Classify the input vectors.

        :param X: iterable of vectors, each of a kind ``vector_distance``
            can compare with the fitted vectors.
        :return: numpy array holding one help-file label per input vector.
        :raises AssertionError: if ``X`` is ``None`` (skipped under
            ``python -O``).
        :raises ValueError: if, for some input, no fitted vector has a
            distance below infinity (e.g. ``fit`` received no vectors, or
            every distance is ``inf``/NaN).
        """
        assert X is not None, "Please provide data for X"
        ret_list = []
        for x in X:  # pylint: disable=invalid-name
            # Remember the *index* of the closest vector rather than the
            # vector itself: finding it again with list.index() rescans the
            # data and fails on numpy rows, whose == is elementwise.
            min_val = float("inf")
            min_index = None
            for index, vec in enumerate(self.vectors):
                dist = vector_distance(x, vec)
                # Strict comparison: on ties the earliest vector wins.
                if dist < min_val:
                    min_val = dist
                    min_index = index

            if min_index is None:
                raise ValueError(
                    "Could not find a nearest fitted vector for an input; "
                    "was fit() called with an empty X?")

            # find corresponding filename
            ret_list.append(self.files[min_index])
        return np.array(ret_list)
61
62
def main():
    """
    Smoke-test HelpPagePredictor.

    First on three hand-written keyword-count dictionaries: print them, the
    distance from a probe dictionary to each one, and the predicted label.
    Then on the table built by ``create_table`` from ``KEYWORDS_FILE``,
    printing the prediction for the keyword counts of "extract.lib".
    """
    print("Running some tests")
    predictor = HelpPagePredictor()

    toy_vectors = [
        {"hello": 1, "bye": 4, "pizza": 10},
        {"hello": 2, "bye": 3, "pizza": 1},
        {"hello": 3, "bye": 9, "pizza": 3},
    ]
    toy_table = np.array(toy_vectors)
    toy_labels = np.array(["file1", "file2", "file3"])
    print(toy_table)
    print(toy_labels)

    probe = {"hello": 1, "bye": 1, "pizza": 1}

    # Show how far the probe is from each toy vector before predicting.
    headings = ("distance to 1", "distance to 2", "distance to 3")
    for heading, candidate in zip(headings, toy_vectors):
        print(heading)
        print(vector_distance(probe, candidate))
        print()

    predictor.fit(toy_table, toy_labels)
    print(predictor.predict(np.array([probe])))

    # Same predictor, refitted on the real keyword table.
    dictionary = read_dictionary(KEYWORDS_FILE)
    table, labels = create_table(dictionary=dictionary)
    sample = count_occurances("extract.lib", dictionary)
    print(predictor.fit(table, labels).predict(np.array([sample])))


if __name__ == '__main__':
    main()
Note: See TracBrowser for help on using the repository browser.