Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0 2 : /** 3 : * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com> 4 : * 5 : * @file data_producer.h 6 : * @date 09 July 2021 7 : * @brief This file contains data producer interface 8 : * @see https://github.com/nnstreamer/nntrainer 9 : * @author Jihoon Lee <jhoon.it.lee@samsung.com> 10 : * @bug No known bugs except for NYI items 11 : * 12 : */ 13 : #ifndef __DATA_PRODUCER_H__ 14 : #define __DATA_PRODUCER_H__ 15 : 16 : #include <functional> 17 : #include <limits> 18 : #include <string> 19 : #include <tuple> 20 : #include <vector> 21 : 22 : #include <common.h> 23 : #include <tensor.h> 24 : #include <tensor_dim.h> 25 : namespace nntrainer { 26 : 27 : class Exporter; 28 : 29 : /** 30 : * @brief DataProducer interface used to abstract data provider 31 : * 32 : */ 33 221 : class DataProducer { 34 : public: 35 : /** 36 : * @brief generator callable type which will fill a sample 37 : * @param[in] index current index with range of [0, size() - 1]. If 38 : * size() == SIZE_UNDEFINED, this parameter can be ignored 39 : * @param[out] inputs allocate tensor before expected to be filled by this 40 : * function 41 : * @param[out] labels allocate tensor before expected to be filled by this 42 : * function function. 43 : * @return bool true if this is the last sample, samples will NOT be ignored 44 : * and should be used, or passed at will of caller 45 : * 46 : */ 47 : using Generator = std::function<bool(unsigned int, /** index */ 48 : std::vector<Tensor> & /** inputs */, 49 : std::vector<Tensor> & /** labels */)>; 50 : 51 : constexpr inline static unsigned int SIZE_UNDEFINED = 52 : std::numeric_limits<unsigned int>::max(); 53 : 54 : /** 55 : * @brief Destroy the Data Loader object 56 : * 57 : */ 58 219 : virtual ~DataProducer() {} 59 : 60 : /** 61 : * @brief Get the producer type 62 : * @return const std::string type representation 63 : */ 64 : virtual const std::string getType() const = 0; 65 : 66 : /** 67 : * @brief Set the Property object 68 : * 69 : * @param properties properties to set 70 : */ 71 0 : virtual void setProperty(const std::vector<std::string> &properties) { 72 0 : if (!properties.empty()) { 73 0 : throw std::invalid_argument("There are unparsed properties"); 74 : } 75 0 : } 76 : 77 : /** 78 : * @brief finalize the class to return an immutable Generator. 79 : * @remark this function must assume that the batch dimension of each tensor 80 : * dimension is one. If actual dimension is not one, this function must ignore 81 : * the batch dimension and assume it to be one. 82 : * @param input_dims input dimensions. 83 : * @param label_dims label dimensions. 84 : * @param user_data user data to be used when finalize. 85 : * @return Generator generator is a function that generates a sample upon 86 : * call. 87 : */ 88 0 : virtual Generator finalize(const std::vector<TensorDim> &input_dims, 89 : const std::vector<TensorDim> &label_dims, 90 : void *user_data = nullptr) { 91 0 : return Generator(); 92 : } 93 : 94 : /** 95 : * @brief get the number of samples inside the dataset, if size 96 : * cannot be determined, this function must return. 97 : * DataProducer::SIZE_UNDEFINED. 98 : * @remark this function must assume that the batch dimension of each tensor 99 : * dimension is one. If actual dimension is not one, this function must ignore 100 : * the batch dimension and assume it to be one 101 : * @param input_dims input dimensions 102 : * @param label_dims label dimensions 103 : * 104 : * @return size calculated size 105 : */ 106 1331 : virtual unsigned int size(const std::vector<TensorDim> &input_dims, 107 : const std::vector<TensorDim> &label_dims) const { 108 1331 : return SIZE_UNDEFINED; 109 : } 110 : 111 : /** 112 : * @brief this function helps exporting the dataproducer in a predefined 113 : * format, while workarounding issue caused by templated function type eraser 114 : * 115 : * @param exporter exporter that conatins exporting logic 116 : * @param method enum value to identify how it should be exported to 117 : */ 118 0 : virtual void exportTo(Exporter &exporter, 119 0 : const ml::train::ExportMethods &method) const {} 120 : 121 : /** 122 : * @brief denote if given producer is thread safe and can be parallelized. 123 : * @note if size() == SIZE_UNDEFIEND, thread safe shall be false 124 : * 125 : * @return bool true if thread safe. 126 : */ 127 0 : virtual bool isMultiThreadSafe() const { return false; } 128 : }; 129 : } // namespace nntrainer 130 : #endif // __DATA_PRODUCER_H__