Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0 2 : /** 3 : * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com> 4 : * 5 : * @file optimizer.h 6 : * @date 14 October 2020 7 : * @see https://github.com/nnstreamer/nntrainer 8 : * @author Jijoong Moon <jijoong.moon@samsung.com> 9 : * @author Parichay Kapoor <pk.kapoor@samsung.com> 10 : * @bug No known bugs except for NYI items 11 : * @brief This is optimizers interface for c++ API 12 : * 13 : * @note This is experimental API and not stable. 14 : */ 15 : 16 : #ifndef __ML_TRAIN_OPTIMIZER_H__ 17 : #define __ML_TRAIN_OPTIMIZER_H__ 18 : 19 : #if __cplusplus >= MIN_CPP_VERSION 20 : 21 : #include <string> 22 : #include <vector> 23 : 24 : #include <common.h> 25 : 26 : namespace ml { 27 : namespace train { 28 : 29 : /** forward declaration */ 30 : class LearningRateScheduler; 31 : 32 : /** 33 : * @brief Enumeration of optimizer type 34 : */ 35 : enum OptimizerType { 36 : ADAM = ML_TRAIN_OPTIMIZER_TYPE_ADAM, /** adam */ 37 : SGD = ML_TRAIN_OPTIMIZER_TYPE_SGD, /** sgd */ 38 : UNKNOWN = ML_TRAIN_OPTIMIZER_TYPE_UNKNOWN /** unknown */ 39 : }; 40 : 41 : /** 42 : * @class Optimizer Base class for optimizers 43 : * @brief Base class for all optimizers 44 : */ 45 864 : class Optimizer { 46 : public: 47 : /** 48 : * @brief Destructor of Optimizer Class 49 : */ 50 863 : virtual ~Optimizer() = default; 51 : 52 : /** 53 : * @brief get Optimizer Type 54 : * @retval Optimizer type 55 : */ 56 : virtual const std::string getType() const = 0; 57 : 58 : /** 59 : * @brief Default allowed properties 60 : * Available for all optimizers 61 : * - learning_rate : float 62 : * 63 : * Available for SGD and Adam optimizers 64 : * - decay_rate : float, 65 : * - decay_steps : float, 66 : * 67 : * Available for Adam optimizer 68 : * - beta1 : float, 69 : * - beta2 : float, 70 : * - epsilon : float, 71 : */ 72 : 73 : /** 74 : * @brief set Optimizer Parameters 75 : * @param[in] values Optimizer Parameter list 76 : * @details This function accepts vector of properties in the format - 77 : * { std::string property_name, void * property_val, ...} 78 : */ 79 : virtual void setProperty(const std::vector<std::string> &values) = 0; 80 : 81 : /** 82 : * @brief Set the Learning Rate Scheduler object 83 : * 84 : * @param lrs the learning rate scheduler object 85 : */ 86 : virtual int setLearningRateScheduler( 87 : std::shared_ptr<ml::train::LearningRateScheduler> lrs) = 0; 88 : }; 89 : 90 : /** 91 : * @brief Factory creator with constructor for optimizer 92 : */ 93 : std::unique_ptr<Optimizer> 94 : createOptimizer(const std::string &type, 95 : const std::vector<std::string> &properties = {}); 96 : 97 : /** 98 : * @brief Factory creator with constructor for optimizer 99 : */ 100 : std::unique_ptr<Optimizer> 101 : createOptimizer(const OptimizerType &type, 102 : const std::vector<std::string> &properties = {}); 103 : 104 : /** 105 : * @brief General Optimizer Factory function to register optimizer 106 : * 107 : * @param props property representation 108 : * @return std::unique_ptr<ml::train::Optimizer> created object 109 : */ 110 : template <typename T, 111 : std::enable_if_t<std::is_base_of<Optimizer, T>::value, T> * = nullptr> 112 : std::unique_ptr<Optimizer> 113 : createOptimizer(const std::vector<std::string> &props = {}) { 114 : std::unique_ptr<Optimizer> ptr = std::make_unique<T>(); 115 : 116 : ptr->setProperty(props); 117 : return ptr; 118 : } 119 : 120 : namespace optimizer { 121 : 122 : /** 123 : * @brief Helper function to create adam optimizer 124 : */ 125 : inline std::unique_ptr<Optimizer> 126 5 : Adam(const std::vector<std::string> &properties = {}) { 127 5 : return createOptimizer(OptimizerType::ADAM, properties); 128 : } 129 : 130 : /** 131 : * @brief Helper function to create sgd optimizer 132 : */ 133 : inline std::unique_ptr<Optimizer> 134 3 : SGD(const std::vector<std::string> &properties = {}) { 135 3 : return createOptimizer(OptimizerType::SGD, properties); 136 : } 137 : 138 : } // namespace optimizer 139 : 140 : /** 141 : * @brief Enumeration of learning rate scheduler type 142 : */ 143 : enum LearningRateSchedulerType { 144 : CONSTANT = ML_TRAIN_LR_SCHEDULER_TYPE_CONSTANT, /**< constant */ 145 : EXPONENTIAL = 146 : ML_TRAIN_LR_SCHEDULER_TYPE_EXPONENTIAL, /**< exponentially decay */ 147 : STEP = ML_TRAIN_LR_SCHEDULER_TYPE_STEP /**< step wise decay */ 148 : }; 149 : 150 : /** 151 : * @class Learning Rate Schedulers Base class 152 : * @brief Base class for all Learning Rate Schedulers 153 : */ 154 884 : class LearningRateScheduler { 155 : 156 : public: 157 : /** 158 : * @brief Destructor of learning rate scheduler Class 159 : */ 160 0 : virtual ~LearningRateScheduler() = default; 161 : 162 : /** 163 : * @brief Default allowed properties 164 : * Constant Learning rate scheduler 165 : * - learning_rate : float 166 : * 167 : * Exponential Learning rate scheduler 168 : * - learning_rate : float 169 : * - decay_rate : float, 170 : * - decay_steps : float, 171 : * 172 : * Step Learning rate scheduler 173 : * - learing_rate : float, float, ... 174 : * - iteration : uint, uint, ... 175 : * 176 : * more to be added 177 : */ 178 : 179 : /** 180 : * @brief set learning rate scheduler properties 181 : * @param[in] values learning rate scheduler properties list 182 : * @details This function accepts vector of properties in the format - 183 : * { std::string property_name = std::string property_val, ...} 184 : */ 185 : virtual void setProperty(const std::vector<std::string> &values) = 0; 186 : 187 : /** 188 : * @brief get learning rate scheduler Type 189 : * @retval learning rate scheduler type 190 : */ 191 : virtual const std::string getType() const = 0; 192 : }; 193 : 194 : /** 195 : * @brief Factory creator with constructor for learning rate scheduler type 196 : */ 197 : std::unique_ptr<ml::train::LearningRateScheduler> 198 : createLearningRateScheduler(const LearningRateSchedulerType &type, 199 : const std::vector<std::string> &properties = {}); 200 : 201 : /** 202 : * @brief Factory creator with constructor for learning rate scheduler 203 : */ 204 : std::unique_ptr<ml::train::LearningRateScheduler> 205 : createLearningRateScheduler(const std::string &type, 206 : const std::vector<std::string> &properties = {}); 207 : 208 : /** 209 : * @brief General LR Scheduler Factory function to create LR Scheduler 210 : * 211 : * @param props property representation 212 : * @return std::unique_ptr<nntrainer::LearningRateScheduler> created object 213 : */ 214 : template <typename T, 215 : std::enable_if_t<std::is_base_of<LearningRateScheduler, T>::value, T> 216 : * = nullptr> 217 : std::unique_ptr<LearningRateScheduler> 218 56 : createLearningRateScheduler(const std::vector<std::string> &props = {}) { 219 56 : std::unique_ptr<LearningRateScheduler> ptr = std::make_unique<T>(); 220 56 : ptr->setProperty(props); 221 56 : return ptr; 222 : } 223 : 224 : namespace optimizer { 225 : namespace learning_rate { 226 : 227 : /** 228 : * @brief Helper function to create constant learning rate scheduler 229 : */ 230 : inline std::unique_ptr<LearningRateScheduler> 231 : Constant(const std::vector<std::string> &properties = {}) { 232 : return createLearningRateScheduler(LearningRateSchedulerType::CONSTANT, 233 : properties); 234 : } 235 : 236 : /** 237 : * @brief Helper function to create exponential learning rate scheduler 238 : */ 239 : inline std::unique_ptr<LearningRateScheduler> 240 2 : Exponential(const std::vector<std::string> &properties = {}) { 241 2 : return createLearningRateScheduler(LearningRateSchedulerType::EXPONENTIAL, 242 2 : properties); 243 : } 244 : 245 : /** 246 : * @brief Helper function to create step learning rate scheduler 247 : */ 248 : inline std::unique_ptr<LearningRateScheduler> 249 : Step(const std::vector<std::string> &properties = {}) { 250 : return createLearningRateScheduler(LearningRateSchedulerType::STEP, 251 : properties); 252 : } 253 : 254 : } // namespace learning_rate 255 : } // namespace optimizer 256 : 257 : } // namespace train 258 : } // namespace ml 259 : 260 : #else 261 : #error "CPP versions c++17 or over are only supported" 262 : #endif // __cpluscplus 263 : #endif // __ML_TRAIN_OPTIMIZER_H__