source: git/machine_learning/tests/model/test_predictor.py @ 50e872

spielwiese
Last change on this file since 50e872 was 50e872, checked in by Murray Heymann <heymann.murray@…>, 5 years ago
Finish migrating to unittest framework
  • Property mode set to 100644
File size: 1.5 KB
Line 
1import os
2import unittest
3import numpy as np
4
5from model.predictor import *
6from common.constants import KEYWORDS_FILE
7
8class TestPredictionMethods(unittest.TestCase):
9
10    def test_fit(self):
11        predictor = HelpPagePredictor()
12
13        self.assertRaises(AssertionError,
14                          predictor.fit,
15                          None,
16                          np.array([]))
17
18        self.assertRaises(AssertionError,
19                          predictor.fit,
20                          np.array([]),
21                          None)
22        predictor.fit(np.array([]),np.array([]))
23
24    def test_predict(self):
25        predictor = HelpPagePredictor()
26        vector1 = normalise_vector(np.array([1, 4, 10]))
27        vector2 = normalise_vector(np.array([2, 3, 1]))
28        vector3 = normalise_vector(np.array([3, 9, 3]))
29
30        vectors = np.array([vector1, vector2, vector3])
31        files = np.array(["file1", "file2", "file3"])
32
33        testvec = normalise_vector(np.array([1, 1, 1]))
34
35        print("distance to 1")
36        print(vector_distance(testvec, vector1))
37        print()
38        print("distance to 2")
39        print(vector_distance(testvec, vector2))
40        print()
41        print("distance to 3")
42        print(vector_distance(testvec, vector3))
43        print()
44
45        predictor.fit(vectors, files)
46        prediction = predictor.predict(np.array([testvec]))
47        print("Prediction:")
48        print(prediction)
49        self.assertEqual(prediction[0], "file2")
50
51
52if __name__ == '__main__':
53    unittest.main()
Note: See TracBrowser for help on using the repository browser.