Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0 2 : /** 3 : * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com> 4 : * 5 : * @file factory.cpp 6 : * @date 14 October 2020 7 : * @see https://github.com/nnstreamer/nntrainer 8 : * @author Parichay Kapoor <pk.kapoor@samsung.com> 9 : * @bug No known bugs except for NYI items 10 : * @brief This is implementaion for factory builder interface for c++ API 11 : */ 12 : 13 : #include <memory> 14 : #include <string> 15 : #include <vector> 16 : 17 : #include <app_context.h> 18 : #include <databuffer.h> 19 : #include <databuffer_factory.h> 20 : #include <layer.h> 21 : #include <model.h> 22 : #include <neuralnet.h> 23 : #include <nntrainer_error.h> 24 : #include <optimizer.h> 25 : #include <optimizer_wrapped.h> 26 : 27 : namespace ml { 28 : namespace train { 29 : 30 81 : std::unique_ptr<Layer> createLayer(const LayerType &type, 31 : const std::vector<std::string> &properties) { 32 81 : return nntrainer::createLayerNode(type, properties); 33 : } 34 : 35 : /** 36 : * @brief Factory creator with constructor for layer 37 : */ 38 11 : std::unique_ptr<Layer> createLayer(const std::string &type, 39 : const std::vector<std::string> &properties) { 40 11 : return nntrainer::createLayerNode(type, properties); 41 : } 42 : 43 : std::unique_ptr<Optimizer> 44 28 : createOptimizer(const OptimizerType &type, 45 : const std::vector<std::string> &properties) { 46 28 : return nntrainer::createOptimizerWrapped(type, properties); 47 : } 48 : 49 : /** 50 : * @brief Factory creator with constructor for optimizer 51 : */ 52 : std::unique_ptr<Optimizer> 53 300 : createOptimizer(const std::string &type, 54 : const std::vector<std::string> &properties) { 55 300 : return nntrainer::createOptimizerWrapped(type, properties); 56 : } 57 : 58 : /** 59 : * @brief Factory creator with constructor for model 60 : */ 61 70 : std::unique_ptr<Model> createModel(ModelType type, 62 : const std::vector<std::string> &properties) { 63 70 : std::unique_ptr<Model> model; 64 70 : switch (type) { 65 68 : case ModelType::NEURAL_NET: 66 136 : model = std::make_unique<nntrainer::NeuralNetwork>(); 67 68 : break; 68 2 : default: 69 2 : throw std::invalid_argument("This type of model is not yet supported"); 70 : } 71 : 72 68 : model->setProperty(properties); 73 : 74 68 : return model; 75 : } 76 : 77 : /** 78 : * @brief creator by copying the configuration of other model 79 : */ 80 3 : std::unique_ptr<Model> copyConfiguration(Model &from) { 81 3 : std::unique_ptr<nntrainer::NeuralNetwork> model = 82 5 : std::make_unique<nntrainer::NeuralNetwork>(); 83 3 : nntrainer::NeuralNetwork &f = dynamic_cast<nntrainer::NeuralNetwork &>(from); 84 3 : model->copyConfiguration(f); 85 2 : return model; 86 : } 87 : 88 : /** 89 : * @brief Factory creator with constructor for dataset 90 : */ 91 : std::unique_ptr<Dataset> 92 2 : createDataset(DatasetType type, const std::vector<std::string> &properties) { 93 2 : std::unique_ptr<Dataset> dataset = nntrainer::createDataBuffer(type); 94 0 : dataset->setProperty(properties); 95 : 96 0 : return dataset; 97 : } 98 : 99 : std::unique_ptr<Dataset> 100 29 : createDataset(DatasetType type, const char *file, 101 : const std::vector<std::string> &properties) { 102 29 : std::unique_ptr<Dataset> dataset = nntrainer::createDataBuffer(type, file); 103 27 : dataset->setProperty(properties); 104 : 105 27 : return dataset; 106 : } 107 : 108 : /** 109 : * @brief Factory creator with constructor for dataset 110 : */ 111 : std::unique_ptr<Dataset> 112 36 : createDataset(DatasetType type, datagen_cb cb, void *user_data, 113 : const std::vector<std::string> &properties) { 114 36 : std::unique_ptr<Dataset> dataset = 115 36 : nntrainer::createDataBuffer(type, cb, user_data); 116 36 : dataset->setProperty(properties); 117 : 118 36 : return dataset; 119 : } 120 : 121 : /** 122 : * @brief Factory creator with constructor for learning rate scheduler type 123 : */ 124 : std::unique_ptr<ml::train::LearningRateScheduler> 125 22 : createLearningRateScheduler(const LearningRateSchedulerType &type, 126 : const std::vector<std::string> &properties) { 127 22 : auto &ac = nntrainer::AppContext::Global(); 128 22 : return ac.createObject<ml::train::LearningRateScheduler>(type, properties); 129 : } 130 : 131 : /** 132 : * @brief Factory creator with constructor for learning rate scheduler 133 : */ 134 : std::unique_ptr<ml::train::LearningRateScheduler> 135 1 : createLearningRateScheduler(const std::string &type, 136 : const std::vector<std::string> &properties) { 137 1 : auto &ac = nntrainer::AppContext::Global(); 138 1 : return ac.createObject<ml::train::LearningRateScheduler>(type, properties); 139 : } 140 : 141 : } // namespace train 142 : } // namespace ml