Changeset 50e872 in git


Ignore:
Timestamp:
Jul 30, 2019, 8:28:06 PM (4 years ago)
Author:
Murray Heymann <heymann.murray@…>
Branches:
jengelh-datetime (ceac47cbc86fe4a15902392bdbb9bd2ae0ea02c6), spielwiese (0604212ebb110535022efecad887940825b97c3f)
Children:
fece1392f8e9ff07b64d0ed4e5ec57bfa6dbf258
Parents:
da892581a52069935f084604d05c0ecd6d19d5c9
Message:
Finish migrating to unittest framework
Location:
machine_learning
Files:
2 added
1 deleted
6 edited

Legend:

Unmodified
Added
Removed
  • machine_learning/.coveragerc

    rda8925 r50e872  
    11[run]
    22branch = True
    3 omit = tests/*
     3omit = tests/*, *__init__.py, predictor_runner.py
    44
    55
  • machine_learning/common/lookuptable.py

    rda8925 r50e872  
    5757
    5858    # sort alphabetically
    59     dictionary = np.sort(dictionary)
     59    dictionary = np.sort(np.unique(dictionary))
    6060    print(dictionary)
    6161
     
    9999
    100100
    101 def main():
    102     """
    103     Run some tests to check if the functions work.
    104     """
    105     fetch_tbz2_data()
    106     for file in get_list_of_htm_files():
    107         print(file)
    108     extract_keywords()
    109     vectors, files = create_table(attempt_cached=False)
    110     vectors1, files1 = create_table()
    111 
    112     if not (vectors == vectors1).all():
    113         print("Cached version differs from original version")
    114     elif not (files == files1).all():
    115         print("Cached version differs from original version")
    116     else:
    117         print("Cached version corresponds with original")
    118 
    119     dictionary = read_dictionary(KEYWORDS_FILE)
    120     test_vec = count_occurances(os.path.join(HELP_FILE_PATH, "html",
    121                                              files[1]), dictionary)
    122     print((test_vec == vectors[1]).all())
    123 
    124 
    125 if __name__ == '__main__':
    126     main()
  • machine_learning/model/predictor.py

    rda8925 r50e872  
    6666            ret_list.append(file)
    6767        return np.array(ret_list)
    68 
    69 
    70 def basic_vector_tests():
    71     """
    72     Some basic sanity tests
    73     """
    74     predictor = HelpPagePredictor()
    75     vector1 = normalise_vector(np.array([1, 4, 10]))
    76     vector2 = normalise_vector(np.array([2, 3, 1]))
    77     vector3 = normalise_vector(np.array([3, 9, 3]))
    78 
    79     vectors = np.array([vector1, vector2, vector3])
    80     files = np.array(["file1", "file2", "file3"])
    81     print(vectors)
    82     print(files)
    83     print()
    84 
    85     testvec = normalise_vector(np.array([1, 1, 1]))
    86     print("test vector:")
    87     print(testvec)
    88     print()
    89 
    90     print("distance to 1")
    91     print(vector_distance(testvec, vector1))
    92     print()
    93     print("distance to 2")
    94     print(vector_distance(testvec, vector2))
    95     print()
    96     print("distance to 3")
    97     print(vector_distance(testvec, vector3))
    98     print()
    99 
    100     predictor.fit(vectors, files)
    101     prediction = predictor.predict(np.array([testvec]))
    102     print("Prediction:")
    103     print(prediction)
    104     print()
    105 
    106 
    107 def main():
    108     """
    109     Run some basic tests
    110     """
    111     print("Running some tests")
    112 
    113     basic_vector_tests()
    114 
    115     dictionary = read_dictionary(KEYWORDS_FILE)
    116 
    117     start = time.time()
    118     vectors, file_list = create_table(dictionary=dictionary)
    119     end = time.time()
    120     print(end - start, "seconds to create_table")
    121 
    122     predictor = HelpPagePredictor()
    123     predictor.fit(vectors, file_list)
    124 
    125     start = time.time()
    126     test_vec = count_occurances("extract.lib", dictionary)
    127     prediction = predictor.predict(np.array([test_vec]))
    128     end = time.time()
    129     print(end - start, "seconds to make prediction")
    130     print(prediction)
    131     print()
    132 
    133     print("prediction for zero vector")
    134     start = time.time()
    135     zerovec = np.zeros(len(dictionary) - 2)
    136     prediction = predictor.predict(np.array([zerovec]))
    137     end = time.time()
    138     print(end - start, "seconds to make prediction")
    139     print(prediction)
    140     print()
    141 
    142     if len(sys.argv) >= 2:
    143         for i in range(len(sys.argv)):
    144             if i == 0:
    145                 continue
    146             if not os.path.isfile(sys.argv[i]):
    147                 continue
    148             print("predicting for file", sys.argv[i])
    149             start = time.time()
    150             test_vec = count_occurances(sys.argv[i], dictionary)
    151             prediction = predictor.predict(np.array([test_vec]))
    152             end = time.time()
    153             print(end - start, "seconds to make prediction")
    154             print(prediction)
    155             print()
    156 
    157 
    158 if __name__ == '__main__':
    159     #cProfile.run("main()")
    160     main()
  • machine_learning/requirements.txt

    rda8925 r50e872  
     1coverage==4.5.4
    12numpy==1.16.1
    23#pandas==0.24.0
    34pylint==2.3.0
     5pytest==5.0.1
     6pytest-cov==2.7.1
    47scikit-learn==0.20.2
    58scipy==1.2.0
  • machine_learning/tests/common/test_keyword_vectors.py

    rda8925 r50e872  
    130130
    131131
    132 
    133132if __name__ == '__main__':
    134133    unittest.main()
  • machine_learning/tests/common/test_lookuptable.py

    rda8925 r50e872  
     1import os
     2import unittest
     3import numpy as np
     4
     5from common.lookuptable import *
     6from common.constants import KEYWORDS_FILE
     7
     8class TestLookuptableMethods(unittest.TestCase):
     9
     10    def test_get_list_of_htm_files(self):
     11        os.system("rm -r " + HELP_FILE_PATH)
     12        fetch_tbz2_data()
     13        fetch_tbz2_data()
     14        files = get_list_of_htm_files()
     15        self.assertGreater(len(files), 0)
     16
     17    def test_extract_keywords(self):
     18        extract_keywords()
     19        self.assertTrue(os.path.isfile(KEYWORDS_FILE))
     20
     21    def test_create_table(self):
     22        dictionary = read_dictionary(KEYWORDS_FILE)
     23        vectors, files = create_table(dictionary, attempt_cached=False)
     24        vectors1, files1 = create_table()
     25        self.assertTrue((vectors == vectors1).all())
     26        self.assertTrue((files == files1).all())
     27
     28        dictionary = read_dictionary(KEYWORDS_FILE)
     29        test_vec = count_occurances(os.path.join(HELP_FILE_PATH, "html",
     30                                                 files[1]), dictionary)
     31        self.assertTrue((test_vec == vectors[1]).all())
     32
     33if __name__ == '__main__':
     34    unittest.main()
Note: See TracChangeset for help on using the changeset viewer.