40 #define SINK_CAPS_STRING GST_TENSORS_CAP_MAKE ("{ static, flexible }")
45 #define SRC_CAPS_STRING GST_TENSORS_CAP_MAKE ("{ static}")
50 static GstStaticPadTemplate
sink_template = GST_STATIC_PAD_TEMPLATE (
"sink",
58 static GstStaticPadTemplate
src_template = GST_STATIC_PAD_TEMPLATE (
"src",
64 #define GST_CAT_DEFAULT gst_tensor_trainer_debug
65 #define gst_tensor_trainer_parent_class parent_class
79 #define MODEL_STATS_SIZE 4
84 #define DEFAULT_PROP_INPUT_LIST 1
85 #define DEFAULT_PROP_LABEL_LIST 1
86 #define DEFAULT_PROP_TRAIN_SAMPLES 0
87 #define DEFAULT_PROP_VALID_SAMPLES 0
88 #define DEFAULT_PROP_EPOCHS 1
92 #define DEFAULT_STR_PROP_VALUE ""
112 const GValue * value, GParamSpec * pspec);
114 GValue * value, GParamSpec * pspec);
117 GstObject * parent, GstEvent * event);
119 GstObject * parent, GstQuery * query);
121 GstObject * parent, GstQuery * query);
123 GstObject * parent, GstBuffer * inbuf);
125 GstPad * pad, GstCaps * filter);
127 element, GstStateChange transition);
130 const GValue * value);
132 * trainer,
const GValue * value);
134 const GValue * value);
136 const GValue * value);
142 guint index, gboolean is_input);
157 GObjectClass *gobject_class;
158 GstElementClass *gstelement_class;
161 "Tensor trainer to train neural network model");
163 gobject_class = G_OBJECT_CLASS (klass);
164 gstelement_class = GST_ELEMENT_CLASS (klass);
166 gobject_class->set_property =
168 gobject_class->get_property =
173 gstelement_class->change_state =
178 g_param_spec_string (
"framework",
"Framework",
179 "(not nullable) Neural network framework to be used for model training, ",
181 G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
182 G_PARAM_STATIC_STRINGS));
185 g_param_spec_string (
"model-config",
"Model configuration file path",
186 "(not nullable) Model configuration file is used to configure the model "
187 "to be trained in neural network framework, set the file path",
189 G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
190 G_PARAM_STATIC_STRINGS));
193 g_param_spec_string (
"model-save-path",
"Model save path",
194 "(not nullable) Path to save the trained model in framework, if model-config "
195 "contains information about the save file, it is ignored",
197 G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
198 G_PARAM_STATIC_STRINGS));
201 g_param_spec_string (
"model-load-path",
"Model load path",
202 "(nullable) Path to a model file to be loaded for the given training session.",
204 G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
205 G_PARAM_STATIC_STRINGS));
208 g_param_spec_uint (
"num-inputs",
"Number of inputs",
209 "An input in a tensor can have one or more features data,"
211 G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
212 G_PARAM_STATIC_STRINGS));
215 g_param_spec_uint (
"num-labels",
"Number of labels",
216 "A label in a tensor can have one or more classes data,"
218 G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
219 G_PARAM_STATIC_STRINGS));
222 g_param_spec_uint (
"num-training-samples",
"Number of training samples",
223 "A sample can consist of multiple inputs and labels in tensors of a gstbuffer"
224 ", set how many samples are taken for training model",
226 G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
227 G_PARAM_STATIC_STRINGS));
230 g_param_spec_uint (
"num-validation-samples",
231 "Number of validation samples",
232 "A sample can consist of multiple inputs and labels in tensors of a gstbuffer"
233 ", set how many samples are taken for validation model",
235 G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
236 G_PARAM_STATIC_STRINGS));
238 g_object_class_install_property (gobject_class,
PROP_EPOCHS,
239 g_param_spec_uint (
"epochs",
"Number of epoch",
240 "Epochs are repetitions of training samples and validation samples, "
241 "number of samples received for model training is "
242 "(num-training-samples+num-validation-samples)*epochs", 0, G_MAXINT,
244 G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
245 G_PARAM_STATIC_STRINGS));
247 gst_element_class_set_details_simple (gstelement_class,
"TensorTrainer",
248 "Trainer/Tensor",
"Train tensor data using NN Frameworks",
249 "Samsung Electronics Co., Ltd.");
252 gst_element_class_add_pad_template (gstelement_class,
254 gst_element_class_add_pad_template (gstelement_class,
264 GST_DEBUG (
"<ENTER>");
267 gst_pad_set_event_function (trainer->
sinkpad,
269 gst_pad_set_query_function (trainer->
sinkpad,
271 gst_pad_set_chain_function (trainer->
sinkpad,
273 GST_PAD_SET_PROXY_CAPS (trainer->
sinkpad);
274 gst_element_add_pad (GST_ELEMENT (trainer), trainer->
sinkpad);
278 gst_pad_set_query_function (trainer->
srcpad,
280 GST_PAD_SET_PROXY_CAPS (trainer->
srcpad);
281 gst_element_add_pad (GST_ELEMENT (trainer), trainer->
srcpad);
343 G_OBJECT_CLASS (parent_class)->finalize (
object);
351 const GValue * value, GParamSpec * pspec)
386 G_OBJECT_WARN_INVALID_PROPERTY_ID (
object, prop_id, pspec);
396 GValue * value, GParamSpec * pspec)
431 G_OBJECT_WARN_INVALID_PROPERTY_ID (
object, prop_id, pspec);
442 g_return_val_if_fail (trainer != NULL,
FALSE);
452 GST_ERROR_OBJECT (trainer,
"Check for invalid param value");
458 (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) {
459 GST_ERROR_OBJECT (trainer,
"Model config file does not exist. [%s]",
476 g_return_val_if_fail (trainer != NULL, NULL);
487 GST_INFO_OBJECT (trainer,
"cur_epoch_data_cnt=%u",
489 GST_INFO_OBJECT (trainer,
"num_tensors=%d",
497 GST_ERROR_OBJECT (trainer,
"Failed to push dummy data");
512 static GstStateChangeReturn
514 GstStateChange transition)
517 GstStateChangeReturn ret = GST_STATE_CHANGE_SUCCESS;
519 switch (transition) {
520 case GST_STATE_CHANGE_NULL_TO_READY:
521 GST_INFO_OBJECT (trainer,
"NULL_TO_READY");
526 case GST_STATE_CHANGE_READY_TO_PAUSED:
527 GST_INFO_OBJECT (trainer,
"READY_TO_PAUSED");
530 case GST_STATE_CHANGE_PAUSED_TO_PLAYING:
531 GST_INFO_OBJECT (trainer,
"PAUSED_TO_PLAYING");
533 goto state_change_failed;
536 goto state_change_failed;
546 ret = GST_ELEMENT_CLASS (parent_class)->change_state (element, transition);
548 switch (transition) {
549 case GST_STATE_CHANGE_PLAYING_TO_PAUSED:
550 GST_INFO_OBJECT (trainer,
"PLAYING_TO_PAUSED");
553 if (!g_strcmp0 (trainer->
fw_name,
"nntrainer")) {
554 GST_INFO_OBJECT (trainer,
"cur_epoch_data_cnt=%u",
557 g_thread_new (
"dumy_data_generation_func",
564 case GST_STATE_CHANGE_PAUSED_TO_READY:
565 GST_INFO_OBJECT (trainer,
"PAUSED_TO_READY");
569 case GST_STATE_CHANGE_READY_TO_NULL:
570 GST_INFO_OBJECT (trainer,
"READY_TO_NULL");
581 GST_ERROR_OBJECT (trainer,
"state change failed");
583 return GST_STATE_CHANGE_FAILURE;
592 g_return_if_fail (trainer != NULL);
596 GST_INFO_OBJECT (trainer,
"wait for epoch_completion_cond signal");
611 g_return_val_if_fail (trainer != NULL,
FALSE);
612 g_return_val_if_fail (trainer->
fw != NULL,
FALSE);
613 g_return_val_if_fail (&trainer->
prop != NULL,
FALSE);
634 GST_WARNING_OBJECT (trainer,
635 "Training is completed, buffer is dropped, please change state of pipeline");
668 gsize header_size = 0;
671 GST_ERROR_OBJECT (trainer,
"Invalid Flexible tensors");
677 GST_INFO (
"flexible header size:%zd", header_size);
688 gboolean in_flexible)
695 gsize header_size = 0, expected;
703 GST_DEBUG_OBJECT (trainer,
"num_tensors: %u",
706 GST_ERROR_OBJECT (trainer,
707 "Invalid memory blocks (%u), number of input tensors may be (%u)",
713 for (i = 0; i < n; i++) {
715 if (!gst_memory_map (in_mem[i], &in_info[i], GST_MAP_READ)) {
716 GST_ERROR_OBJECT (trainer,
"Could not map in_mem[%u] GstMemory", i);
723 &in_meta[i], info, in_info[i].
data);
724 if (header_size == 0)
736 GST_ERROR_OBJECT (trainer,
737 "Invalid tensor size (%u'th memory chunk: %zd), expected size (%zd)",
747 GST_ERROR_OBJECT (trainer,
"push error");
752 for (i = 0; i < n; i++) {
754 gst_memory_unmap (in_mem[i], &in_info[i]);
755 gst_memory_unref (in_mem[i]);
778 GST_ERROR_OBJECT (trainer,
"Failed to Get status from sub-plugin.(%s).",
792 GST_DEBUG_OBJECT (trainer,
793 "#%u/%u epochs [training_loss: %f, training_accuracy: %f, validation_loss: %f, validation_accuracy: %f]",
810 { -INFINITY, -INFINITY, -INFINITY, -INFINITY };
815 gboolean created =
FALSE;
818 GST_ERROR_OBJECT (trainer,
819 "The number of output tensors (%u) exceeds limit (%d)",
824 outbuf = gst_buffer_new ();
833 out_mem = gst_allocator_alloc (NULL, data_size, NULL);
835 GST_ERROR_OBJECT (trainer,
"Failed to allocate memory");
839 if (!gst_memory_map (out_mem, &out_info, GST_MAP_WRITE)) {
840 GST_ERROR_OBJECT (trainer,
"Could not map out_mem[%u] GstMemory", i);
841 gst_memory_unref (out_mem);
845 memcpy (out_info.data, model_stats, sizeof (model_stats));
846 gst_memory_unmap (out_mem, &out_info);
855 GST_INFO (
"out_buffer size : %zd", gst_buffer_get_size (outbuf));
857 gst_buffer_unref (outbuf);
872 GstBuffer *outbuf = NULL;
873 GstFlowReturn ret = GST_FLOW_ERROR;
875 gboolean in_flexible;
904 ret = gst_pad_push (trainer->
srcpad, outbuf);
911 gst_buffer_unref (inbuf);
920 GstPad * pad, GstCaps * filter)
925 g_return_val_if_fail (trainer != NULL, NULL);
926 g_return_val_if_fail (pad != NULL, NULL);
936 GST_DEBUG_OBJECT (trainer,
"caps %" GST_PTR_FORMAT, caps);
937 GST_DEBUG_OBJECT (trainer,
"filter %" GST_PTR_FORMAT, filter);
939 if (caps && filter) {
941 result = gst_caps_intersect_full (filter, caps, GST_CAPS_INTERSECT_FIRST);
942 gst_caps_unref (caps);
946 GST_DEBUG_OBJECT (trainer,
"result caps %" GST_PTR_FORMAT, caps);
957 g_return_if_fail (trainer != NULL);
961 GST_INFO_OBJECT (trainer,
962 "got GST_EVENT_EOS event but training is not completed, state is %d, "
963 "wait for training_completion_cond signal", GST_STATE (trainer));
969 GST_DEBUG_OBJECT (trainer,
"training is completed in sub-plugin[%s]",
983 GST_DEBUG_OBJECT (trainer,
"Received %s event: %" GST_PTR_FORMAT,
984 GST_EVENT_TYPE_NAME (event), event);
986 switch (GST_EVENT_TYPE (event)) {
991 case GST_EVENT_FLUSH_START:
992 GST_INFO_OBJECT (trainer,
"get GST_EVENT_FLUSH_START event");
994 case GST_EVENT_FLUSH_STOP:
995 GST_INFO_OBJECT (trainer,
"get GST_EVENT_FLUSH_STOP event");
1001 GstStructure *structure;
1003 gboolean ret =
FALSE;
1005 gst_event_parse_caps (event, &in_caps);
1006 GST_INFO_OBJECT (trainer,
"[in-caps] : %" GST_PTR_FORMAT, in_caps);
1008 structure = gst_caps_get_structure (in_caps, 0);
1012 gst_event_unref (event);
1028 GST_INFO_OBJECT (trainer,
"[out-caps] : %" GST_PTR_FORMAT, out_caps);
1030 ret = gst_pad_set_caps (trainer->
srcpad, out_caps);
1032 gst_event_unref (event);
1033 gst_caps_unref (out_caps);
1039 return gst_pad_event_default (sinkpad, parent, event);
1052 GST_DEBUG_OBJECT (trainer,
"Received '%s' query: %" GST_PTR_FORMAT,
1053 GST_QUERY_TYPE_NAME (query), query);
1055 switch (GST_QUERY_TYPE (query)) {
1056 case GST_QUERY_CAPS:
1061 GST_DEBUG_OBJECT (trainer,
"[GST_QUERY_CAPS]");
1062 gst_query_parse_caps (query, &filter);
1063 GST_DEBUG_OBJECT (trainer,
"Caps from query : %" GST_PTR_FORMAT, filter);
1067 GST_INFO_OBJECT (trainer,
"[GST_QUERY_CAPS] : %" GST_PTR_FORMAT, caps);
1068 gst_query_set_caps_result (query, caps);
1069 gst_caps_unref (caps);
1073 case GST_QUERY_ACCEPT_CAPS:
1076 GstCaps *template_caps;
1079 GST_DEBUG_OBJECT (trainer,
"[GST_QUERY_ACCEPT_CAPS]");
1080 gst_query_parse_accept_caps (query, &caps);
1081 GST_INFO_OBJECT (trainer,
"Accept caps from query : %" GST_PTR_FORMAT,
1084 if (gst_caps_is_fixed (caps)) {
1085 template_caps = gst_pad_get_pad_template_caps (sinkpad);
1086 GST_DEBUG_OBJECT (trainer,
"sinkpad template_caps : %" GST_PTR_FORMAT,
1089 result = gst_caps_can_intersect (template_caps, caps);
1090 gst_caps_unref (template_caps);
1092 GST_DEBUG_OBJECT (trainer,
"intersect caps : %" GST_PTR_FORMAT, caps);
1095 gst_query_set_accept_caps_result (query,
result);
1102 return gst_pad_query_default (sinkpad, parent, query);
1115 GST_DEBUG_OBJECT (trainer,
"Received %s query: %" GST_PTR_FORMAT,
1116 GST_QUERY_TYPE_NAME (query), query);
1118 switch (GST_QUERY_TYPE (query)) {
1119 case GST_QUERY_CAPS:
1123 GST_DEBUG_OBJECT (trainer,
"[GST_QUERY_CAPS]");
1124 gst_query_parse_caps (query, &filter);
1125 GST_DEBUG_OBJECT (trainer,
"Caps from query : %" GST_PTR_FORMAT, filter);
1128 GST_INFO_OBJECT (trainer,
"[GST_QUERY_CAPS] : %" GST_PTR_FORMAT, caps);
1129 gst_query_set_caps_result (query, caps);
1130 gst_caps_unref (caps);
1136 return gst_pad_query_default (srcpad, parent, query);
1144 const GValue * value)
1147 trainer->
fw_name = g_value_dup_string (value);
1148 GST_INFO_OBJECT (trainer,
"Framework: %s", trainer->
fw_name);
1158 trainer,
const GValue * value)
1162 GST_INFO_OBJECT (trainer,
"Model configuration file path: %s",
1171 const GValue * value)
1175 GST_INFO_OBJECT (trainer,
"File path to save the model: %s",
1184 const GValue * value)
1188 GST_INFO_OBJECT (trainer,
"File path to load the model: %s",
1200 g_return_val_if_fail (name != NULL,
FALSE);
1201 g_return_val_if_fail (trainer != NULL,
FALSE);
1203 GST_INFO_OBJECT (trainer,
"Try to find framework: %s", name);
1207 GST_ERROR_OBJECT (trainer,
"Can not find framework(%s)", trainer->
fw_name);
1211 GST_INFO_OBJECT (trainer,
"Find framework %s:%p", trainer->
fw_name, fw);
1223 g_return_val_if_fail (trainer != NULL,
FALSE);
1226 GST_ERROR_OBJECT (trainer,
"fw is not opened(%d) or fw is not null(%p)",
1232 GST_ERROR_OBJECT (trainer,
"Could not create framework");
1236 GST_DEBUG_OBJECT (trainer,
"%p", trainer->
privateData);
1240 GST_DEBUG_OBJECT (trainer,
"Success, Framework: %p", trainer->
privateData);
1251 guint index, gboolean is_input)
1262 GST_ERROR_OBJECT (trainer,
"has inconsistent data");
1275 gboolean ret =
TRUE;
1277 g_return_val_if_fail (trainer != NULL,
FALSE);
1278 g_return_val_if_fail (trainer->
fw_name != NULL,
FALSE);
1298 g_return_if_fail (trainer != NULL);
1299 g_return_if_fail (trainer->
fw != NULL);
1311 g_return_if_fail (trainer != NULL);
1312 g_return_if_fail (trainer->
fw != NULL);
1313 g_return_if_fail (trainer->
fw->
start != NULL);
1315 GST_DEBUG_OBJECT (trainer,
"Start model training");
1320 GST_ERROR_OBJECT (trainer,
"Model training is failed");
1332 g_return_if_fail (trainer != NULL);
1333 g_return_if_fail (trainer->
fw != NULL);
1334 g_return_if_fail (trainer->
fw->
stop != NULL);
1336 GST_DEBUG_OBJECT (trainer,
"Stop model training");
1339 GST_ERROR_OBJECT (trainer,
"Stopping model training is failed");
1351 g_return_if_fail (trainer != NULL);
1375 const char *name = NULL;
1378 g_return_val_if_fail (ttsp != NULL, 0);
1384 GST_ERROR (
"getFrameworkInfo() failed");
1402 const char *name = NULL;
1405 g_return_val_if_fail (ttsp != NULL, 0);
1411 GST_ERROR (
"getFrameworkInfo() failed");
1429 g_return_if_fail (notifier != NULL);
1430 g_return_if_fail (type < TRAINER_EVENT_UNKNOWN || type > 0);
1436 GST_DEBUG (
"Received GstTensorTrainerEvent(%d)",
type);
1442 GST_DEBUG (
"send epoch_completion_cond signal");
1449 GST_DEBUG (
"send training_completion_cond signal");