Doxygen Book
gsttensor_trainer.c
Go to the documentation of this file.
1 /* SPDX-License-Identifier: LGPL-2.1-only */
27 #ifdef HAVE_CONFIG_H
28 #include "config.h"
29 #endif
30 #include <stdlib.h>
31 #include <nnstreamer_subplugin.h>
32 #include <nnstreamer_util.h>
33 #include "gsttensor_trainer.h"
34 #include <unistd.h>
35 #include <math.h>
36 
/* Caps accepted on the sink pad: tensors in static or flexible format. */
40 #define SINK_CAPS_STRING GST_TENSORS_CAP_MAKE ("{ static, flexible }")
41 
/* Caps produced on the src pad: static tensors only. */
45 #define SRC_CAPS_STRING GST_TENSORS_CAP_MAKE ("{ static}")
46 
/* Always-present sink pad template built from SINK_CAPS_STRING. */
50 static GstStaticPadTemplate sink_template = GST_STATIC_PAD_TEMPLATE ("sink",
51  GST_PAD_SINK,
52  GST_PAD_ALWAYS,
53  GST_STATIC_CAPS (SINK_CAPS_STRING));
54 
/* Always-present src pad template built from SRC_CAPS_STRING. */
58 static GstStaticPadTemplate src_template = GST_STATIC_PAD_TEMPLATE ("src",
59  GST_PAD_SRC,
60  GST_PAD_ALWAYS,
61  GST_STATIC_CAPS (SRC_CAPS_STRING));
62 
/* Debug category used as the default for the GST_* log macros in this file;
 * it is initialised in the class initializer. */
63 GST_DEBUG_CATEGORY_STATIC (gst_tensor_trainer_debug);
64 #define GST_CAT_DEFAULT gst_tensor_trainer_debug
/* G_DEFINE_TYPE declares gst_tensor_trainer_parent_class; this alias lets the
 * vfuncs below chain up through the shorter name `parent_class`. */
65 #define gst_tensor_trainer_parent_class parent_class
/* Registers GstTensorTrainer as a subclass of GstElement. */
66 G_DEFINE_TYPE (GstTensorTrainer, gst_tensor_trainer, GST_TYPE_ELEMENT);
67 
/* Indices into the model-statistics array emitted on the src pad.
 * The enumerators are missing from this listing. TRAINING_LOSS,
 * TRAINING_ACCURACY, VALIDATION_LOSS and VALIDATION_ACCURACY are used as
 * indices elsewhere in this file. */
72 enum
73 {
78 };
/* Number of doubles in the model-statistics array (four statistics above). */
79 #define MODEL_STATS_SIZE 4
80 
/* Default values for the numeric properties.
 * NOTE(review): the visible g_param_spec_uint () calls pass literal defaults
 * (1, 0) instead of these macros; using the macros there would keep the two
 * from drifting apart. */
84 #define DEFAULT_PROP_INPUT_LIST 1
85 #define DEFAULT_PROP_LABEL_LIST 1
86 #define DEFAULT_PROP_TRAIN_SAMPLES 0
87 #define DEFAULT_PROP_VALID_SAMPLES 0
88 #define DEFAULT_PROP_EPOCHS 1
89 
/* Initial value of the framework, model-config and model-save-path strings. */
92 #define DEFAULT_STR_PROP_VALUE ""
93 
/* GObject property ids. The earlier enumerators (PROP_FRAMEWORK,
 * PROP_MODEL_CONFIG, PROP_MODEL_SAVE_PATH, PROP_MODEL_LOAD_PATH are installed
 * in the class initializer) are missing from this listing. */
97 enum
98 {
104  PROP_NUM_INPUTS, /* number of input list */
105  PROP_NUM_LABELS, /* number of label list */
106  PROP_NUM_TRAINING_SAMPLES, /* number of training data */
107  PROP_NUM_VALIDATION_SAMPLES, /* number of validation data */
108  PROP_EPOCHS, /* Repetitions of training */
109 };
110 
/* Forward declarations of the GObject/GstElement vfuncs and pad callbacks
 * defined later in this file. */
111 static void gst_tensor_trainer_set_property (GObject * object, guint prop_id,
112  const GValue * value, GParamSpec * pspec);
113 static void gst_tensor_trainer_get_property (GObject * object, guint prop_id,
114  GValue * value, GParamSpec * pspec);
115 static void gst_tensor_trainer_finalize (GObject * object);
116 static gboolean gst_tensor_trainer_sink_event (GstPad * sinkpad,
117  GstObject * parent, GstEvent * event);
118 static gboolean gst_tensor_trainer_sink_query (GstPad * sinkpad,
119  GstObject * parent, GstQuery * query);
120 static gboolean gst_tensor_trainer_src_query (GstPad * srcpad,
121  GstObject * parent, GstQuery * query);
122 static GstFlowReturn gst_tensor_trainer_chain (GstPad * sinkpad,
123  GstObject * parent, GstBuffer * inbuf);
124 static GstCaps *gst_tensor_trainer_query_caps (GstTensorTrainer * trainer,
125  GstPad * pad, GstCaps * filter);
126 static GstStateChangeReturn gst_tensor_trainer_change_state (GstElement *
127  element, GstStateChange transition);
128 
/* Forward declarations of internal helpers. Several prototype head lines are
 * missing from this listing, so some declarations appear only as their
 * trailing parameter lines. */
130  const GValue * value);
132  * trainer, const GValue * value);
134  const GValue * value);
136  const GValue * value);
137 static gboolean gst_tensor_trainer_find_framework (GstTensorTrainer * trainer,
138  const char *name);
140  trainer);
142  guint index, gboolean is_input);
143 static gboolean gst_tensor_trainer_create_model (GstTensorTrainer * trainer);
145  trainer);
147  trainer);
150 
/**
 * @brief Class initializer (its name/parameter line is missing from this
 * listing).
 *
 * Initialises the debug category, installs set_property / get_property /
 * finalize and change_state, registers every element property, sets the
 * element metadata and adds the src and sink pad templates.
 */
154 static void
156 {
157  GObjectClass *gobject_class;
158  GstElementClass *gstelement_class;
159 
160  GST_DEBUG_CATEGORY_INIT (GST_CAT_DEFAULT, "tensor_trainer", 0,
161  "Tensor trainer to train neural network model");
162 
163  gobject_class = G_OBJECT_CLASS (klass);
164  gstelement_class = GST_ELEMENT_CLASS (klass);
165 
166  gobject_class->set_property =
167  GST_DEBUG_FUNCPTR (gst_tensor_trainer_set_property);
168  gobject_class->get_property =
169  GST_DEBUG_FUNCPTR (gst_tensor_trainer_get_property);
170  gobject_class->finalize = GST_DEBUG_FUNCPTR (gst_tensor_trainer_finalize);
171 
172  /* Called when the element's state changes */
173  gstelement_class->change_state =
174  GST_DEBUG_FUNCPTR (gst_tensor_trainer_change_state);
175 
176  /* Install properties for tensor_trainer */
/* All properties are mutable only up to the READY state
 * (GST_PARAM_MUTABLE_READY). */
177  g_object_class_install_property (gobject_class, PROP_FRAMEWORK,
178  g_param_spec_string ("framework", "Framework",
179  "(not nullable) Neural network framework to be used for model training, ",
181  G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
182  G_PARAM_STATIC_STRINGS));
183 
184  g_object_class_install_property (gobject_class, PROP_MODEL_CONFIG,
185  g_param_spec_string ("model-config", "Model configuration file path",
186  "(not nullable) Model configuration file is used to configure the model "
187  "to be trained in neural network framework, set the file path",
189  G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
190  G_PARAM_STATIC_STRINGS));
191 
192  g_object_class_install_property (gobject_class, PROP_MODEL_SAVE_PATH,
193  g_param_spec_string ("model-save-path", "Model save path",
194  "(not nullable) Path to save the trained model in framework, if model-config "
195  "contains information about the save file, it is ignored",
197  G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
198  G_PARAM_STATIC_STRINGS));
199 
200  g_object_class_install_property (gobject_class, PROP_MODEL_LOAD_PATH,
201  g_param_spec_string ("model-load-path", "Model load path",
202  "(nullable) Path to a model file to be loaded for the given training session.",
204  G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
205  G_PARAM_STATIC_STRINGS));
206 
/* NOTE(review): in the num-inputs and num-labels blurbs below, the two
 * adjacent string literals concatenate without a space, so the help text
 * reads "...features data,set how many..." / "...classes data,set how
 * many...". */
207  g_object_class_install_property (gobject_class, PROP_NUM_INPUTS,
208  g_param_spec_uint ("num-inputs", "Number of inputs",
209  "An input in a tensor can have one or more features data,"
210  "set how many inputs are received", 0, NNS_TENSOR_SIZE_LIMIT, 1,
211  G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
212  G_PARAM_STATIC_STRINGS));
213 
214  g_object_class_install_property (gobject_class, PROP_NUM_LABELS,
215  g_param_spec_uint ("num-labels", "Number of labels",
216  "A label in a tensor can have one or more classes data,"
217  "set how many labels are received", 0, NNS_TENSOR_SIZE_LIMIT, 1,
218  G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
219  G_PARAM_STATIC_STRINGS));
220 
221  g_object_class_install_property (gobject_class, PROP_NUM_TRAINING_SAMPLES,
222  g_param_spec_uint ("num-training-samples", "Number of training samples",
223  "A sample can consist of multiple inputs and labels in tensors of a gstbuffer"
224  ", set how many samples are taken for training model",
225  0, G_MAXINT, 0,
226  G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
227  G_PARAM_STATIC_STRINGS));
228 
229  g_object_class_install_property (gobject_class, PROP_NUM_VALIDATION_SAMPLES,
230  g_param_spec_uint ("num-validation-samples",
231  "Number of validation samples",
232  "A sample can consist of multiple inputs and labels in tensors of a gstbuffer"
233  ", set how many samples are taken for validation model",
234  0, G_MAXINT, 0,
235  G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
236  G_PARAM_STATIC_STRINGS));
237 
238  g_object_class_install_property (gobject_class, PROP_EPOCHS,
239  g_param_spec_uint ("epochs", "Number of epoch",
240  "Epochs are repetitions of training samples and validation samples, "
241  "number of samples received for model training is "
242  "(num-training-samples+num-validation-samples)*epochs", 0, G_MAXINT,
244  G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
245  G_PARAM_STATIC_STRINGS));
246 
247  gst_element_class_set_details_simple (gstelement_class, "TensorTrainer",
248  "Trainer/Tensor", "Train tensor data using NN Frameworks",
249  "Samsung Electronics Co., Ltd.");
250 
251  /* Add pad template */
252  gst_element_class_add_pad_template (gstelement_class,
253  gst_static_pad_template_get (&src_template));
254  gst_element_class_add_pad_template (gstelement_class,
255  gst_static_pad_template_get (&sink_template));
256 }
257 
/**
 * @brief Instance initializer (its name/parameter line is missing from this
 * listing).
 *
 * Creates and adds the sink and src pads with their callbacks. Sets the
 * string properties to their defaults, resets the training counters and
 * flags, and initialises the mutex/cond pairs used to wait for epoch and
 * training completion.
 */
261 static void
263 {
264  GST_DEBUG ("<ENTER>");
/* Sink pad: receives samples through the chain function and handles events
 * and queries. */
266  trainer->sinkpad = gst_pad_new_from_static_template (&sink_template, "sink");
267  gst_pad_set_event_function (trainer->sinkpad,
268  GST_DEBUG_FUNCPTR (gst_tensor_trainer_sink_event));
269  gst_pad_set_query_function (trainer->sinkpad,
270  GST_DEBUG_FUNCPTR (gst_tensor_trainer_sink_query));
271  gst_pad_set_chain_function (trainer->sinkpad,
272  GST_DEBUG_FUNCPTR (gst_tensor_trainer_chain));
273  GST_PAD_SET_PROXY_CAPS (trainer->sinkpad);
274  gst_element_add_pad (GST_ELEMENT (trainer), trainer->sinkpad);
275 
/* Src pad: only a query handler is installed; output buffers are pushed to
 * it from the chain function. */
277  trainer->srcpad = gst_pad_new_from_static_template (&src_template, "src");
278  gst_pad_set_query_function (trainer->srcpad,
279  GST_DEBUG_FUNCPTR (gst_tensor_trainer_src_query));
280  GST_PAD_SET_PROXY_CAPS (trainer->srcpad);
281  gst_element_add_pad (GST_ELEMENT (trainer), trainer->srcpad);
282 
/* Mandatory string properties start as "" (DEFAULT_STR_PROP_VALUE); the
 * optional model load path starts as NULL. */
284  trainer->fw_name = g_strdup (DEFAULT_STR_PROP_VALUE);
285  trainer->prop.model_config = g_strdup (DEFAULT_STR_PROP_VALUE);
286  trainer->prop.model_save_path = g_strdup (DEFAULT_STR_PROP_VALUE);
287  trainer->prop.model_load_path = NULL;
293 
294  trainer->fw = NULL;
295  trainer->fw_created = FALSE;
296  trainer->is_training_complete = FALSE;
297  trainer->is_epoch_complete = FALSE;
298  trainer->cur_epoch_data_cnt = 0;
299  trainer->required_sample = 0;
300 
303 
/* Released in finalize with g_cond_clear () / g_mutex_clear (). */
304  g_cond_init (&trainer->training_completion_cond);
305  g_mutex_init (&trainer->training_completion_lock);
306  g_cond_init (&trainer->epoch_completion_cond);
307  g_mutex_init (&trainer->epoch_completion_lock);
308 
310 }
311 
/**
 * @brief GObject finalize vfunc.
 *
 * Frees the string properties, clears the mutex/cond pairs, joins the
 * dummy-data thread if one was started, destroys the framework instance if
 * one was created, then chains up to the parent class.
 *
 * NOTE(review): the order of operations looks risky:
 *  - The conds/mutexes are cleared before the dummy-data thread is joined
 *    and before fw->destroy () runs. If either can still touch them (for
 *    example a sub-plugin callback signalling a completion cond), this is a
 *    use-after-clear.
 *  - fw->destroy () receives &trainer->prop after prop.model_config,
 *    prop.model_save_path and prop.model_load_path have been freed. If
 *    destroy reads those paths, it reads freed memory.
 *  Joining and destroying first would be safer; confirm against the
 *  sub-plugin implementation.
 */
315 static void
316 gst_tensor_trainer_finalize (GObject * object)
317 {
318  GstTensorTrainer *trainer;
319  trainer = GST_TENSOR_TRAINER (object);
320 
321  g_free (trainer->fw_name);
322  g_free ((char *) trainer->prop.model_config);
323  g_free ((char *) trainer->prop.model_save_path);
324  g_free ((char *) trainer->prop.model_load_path);
325 
328 
329  g_cond_clear (&trainer->training_completion_cond);
330  g_mutex_clear (&trainer->training_completion_lock);
331  g_cond_clear (&trainer->epoch_completion_cond);
332  g_mutex_clear (&trainer->epoch_completion_lock);
333 
334  if (trainer->dummy_data_thread) {
335  g_thread_join (trainer->dummy_data_thread);
336  trainer->dummy_data_thread = NULL;
337  }
338 
339  if (trainer->fw_created && trainer->fw) {
340  trainer->fw->destroy (trainer->fw, &trainer->prop, &trainer->privateData);
341  }
342 
343  G_OBJECT_CLASS (parent_class)->finalize (object);
344 }
345 
/**
 * @brief GObject set_property vfunc.
 *
 * The framework name is set through gst_tensor_trainer_set_prop_framework ().
 * The bodies of the other string-property cases are missing from this
 * listing. The numeric properties are stored directly in trainer->prop.
 * Unknown ids trigger G_OBJECT_WARN_INVALID_PROPERTY_ID.
 */
349 static void
350 gst_tensor_trainer_set_property (GObject * object, guint prop_id,
351  const GValue * value, GParamSpec * pspec)
352 {
353  GstTensorTrainer *trainer;
354 
355  trainer = GST_TENSOR_TRAINER (object);
356 
357  switch (prop_id) {
358  case PROP_FRAMEWORK:
359  gst_tensor_trainer_set_prop_framework (trainer, value);
360  break;
361  case PROP_MODEL_CONFIG:
363  break;
366  break;
369  break;
370  case PROP_NUM_INPUTS:
371  trainer->prop.num_inputs = g_value_get_uint (value);
372  break;
373  case PROP_NUM_LABELS:
374  trainer->prop.num_labels = g_value_get_uint (value);
375  break;
377  trainer->prop.num_training_samples = g_value_get_uint (value);
378  break;
380  trainer->prop.num_validation_samples = g_value_get_uint (value);
381  break;
382  case PROP_EPOCHS:
383  trainer->prop.num_epochs = g_value_get_uint (value);
384  break;
385  default:
386  G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
387  break;
388  }
389 }
390 
/**
 * @brief GObject get_property vfunc.
 *
 * Copies the requested property (the framework name, the model paths stored
 * in trainer->prop, or a numeric setting from trainer->prop) into @value.
 * Some case labels are missing from this listing. Unknown ids trigger
 * G_OBJECT_WARN_INVALID_PROPERTY_ID.
 */
394 static void
395 gst_tensor_trainer_get_property (GObject * object, guint prop_id,
396  GValue * value, GParamSpec * pspec)
397 {
398  GstTensorTrainer *trainer;
399 
400  trainer = GST_TENSOR_TRAINER (object);
401 
402  switch (prop_id) {
403  case PROP_FRAMEWORK:
404  g_value_set_string (value, trainer->fw_name);
405  break;
406  case PROP_MODEL_CONFIG:
407  g_value_set_string (value, trainer->prop.model_config);
408  break;
410  g_value_set_string (value, trainer->prop.model_save_path);
411  break;
413  g_value_set_string (value, trainer->prop.model_load_path);
414  break;
415  case PROP_NUM_INPUTS:
416  g_value_set_uint (value, trainer->prop.num_inputs);
417  break;
418  case PROP_NUM_LABELS:
419  g_value_set_uint (value, trainer->prop.num_labels);
420  break;
422  g_value_set_uint (value, trainer->prop.num_training_samples);
423  break;
425  g_value_set_uint (value, trainer->prop.num_validation_samples);
426  break;
427  case PROP_EPOCHS:
428  g_value_set_uint (value, trainer->prop.num_epochs);
429  break;
430  default:
431  G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
432  break;
433  }
434 }
435 
/**
 * @brief Validate the trainer properties before training starts (the
 * name/parameter line is missing from this listing).
 *
 * Fails when any of these hold:
 *  - fw_name is NULL;
 *  - the g_ascii_strcasecmp () checks on model_config / model_save_path
 *    fail (their other operands are missing from this listing);
 *  - num_epochs, num_inputs or num_labels is 0.
 * It then checks model_config with g_file_test ().
 *
 * @return TRUE if the properties look usable, FALSE otherwise.
 *
 * NOTE(review): fw_name defaults to "" (not NULL), so `!trainer->fw_name`
 * does not catch an unset framework. That case only fails later, when the
 * framework lookup fails.
 */
439 static gboolean
441 {
442  g_return_val_if_fail (trainer != NULL, FALSE);
443 
444  /* Parameters that can be retrieved from caps will be removed */
445  if (!trainer->fw_name
446  || (g_ascii_strcasecmp (trainer->prop.model_config,
448  || (g_ascii_strcasecmp (trainer->prop.model_save_path,
450  || trainer->prop.num_epochs <= 0 || trainer->prop.num_inputs <= 0
451  || trainer->prop.num_labels <= 0) {
452  GST_ERROR_OBJECT (trainer, "Check for invalid param value");
453 
454  return FALSE;
455  }
456 
/* NOTE(review): g_file_test () returns TRUE if ANY of the OR-ed tests is
 * true, so EXISTS | IS_REGULAR only requires that the path exists; a
 * directory is accepted too. Passing G_FILE_TEST_IS_REGULAR alone would
 * enforce a regular file. */
457  if (!g_file_test (trainer->prop.model_config,
458  (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) {
459  GST_ERROR_OBJECT (trainer, "Model config file does not exist. [%s]",
460  trainer->prop.model_config);
461  return FALSE;
462  }
463 
464  return TRUE;
465 }
466 
/**
 * @brief Dummy-data pusher (the name/parameter line is missing from this
 * listing). Judging by the gpointer return and the `trainer` argument, this
 * is presumably the function passed to g_thread_new ("dumy_data_generation_func",
 * ...) in change_state; confirm.
 *
 * Fills buffers with the byte value 1 and keeps pushing them to the
 * framework through fw->push_data () until cur_epoch_data_cnt reaches
 * required_sample. The buffers are freed afterwards. Always returns NULL.
 *
 * NOTE(review): several issues to confirm:
 *  - Both loops are bounded by output_meta.num_tensors, but they fill
 *    input_tensors[] and the log prints prop.input_meta.num_tensors; the
 *    input count looks like the intended bound.
 *  - gst_tensor_trainer_push_input () resets input_tensors[i].size to 0 on
 *    exit. Unless the line missing from this listing restores the sizes,
 *    g_malloc (0)/memset () run on zero-sized (possibly NULL) buffers.
 *    gst_tensor_trainer_get_tensor_size (trainer, i, TRUE) gives the
 *    expected size.
 *  - If push_data () keeps failing, cur_epoch_data_cnt never advances and
 *    the do/while loop never ends.
 *  - On return, input_tensors[i].data still points at the freed buffers.
 *  - cur_epoch_data_cnt is read and written here without a lock, while
 *    push_input () may also increment it from the streaming thread.
 */
470 static gpointer
472 {
473  guint i;
474  gint ret = -1;
475  gpointer dummy_data[NNS_TENSOR_SIZE_LIMIT] = { NULL };
476  g_return_val_if_fail (trainer != NULL, NULL);
477 
479 
480  for (i = 0; i < trainer->output_meta.num_tensors; i++) {
481  dummy_data[i] = g_malloc (trainer->input_tensors[i].size);
482  memset (dummy_data[i], 1, trainer->input_tensors[i].size);
483  trainer->input_tensors[i].data = dummy_data[i];
484  }
485 
/* Pushes at least once, even if required_sample was already reached. */
486  do {
487  GST_INFO_OBJECT (trainer, "cur_epoch_data_cnt=%u",
488  trainer->cur_epoch_data_cnt);
489  GST_INFO_OBJECT (trainer, "num_tensors=%d",
490  trainer->prop.input_meta.num_tensors);
491 
492  ret =
493  trainer->fw->push_data (trainer->fw, &trainer->prop,
494  trainer->privateData, trainer->input_tensors);
495 
496  if (ret < 0) {
497  GST_ERROR_OBJECT (trainer, "Failed to push dummy data");
498  } else {
499  trainer->cur_epoch_data_cnt++;
500  }
501  } while (trainer->required_sample > trainer->cur_epoch_data_cnt);
502 
503  for (i = 0; i < trainer->output_meta.num_tensors; i++)
504  g_free (dummy_data[i]);
505 
506  return NULL;
507 }
508 
/**
 * @brief GstElement change_state vfunc.
 *
 * Before chaining up:
 *  - NULL->READY clears is_training_complete.
 *  - PAUSED->PLAYING jumps to state_change_failed when a check fails (that
 *    condition line is missing from this listing). Otherwise it creates the
 *    model if no framework instance exists yet.
 * After chaining up:
 *  - PLAYING->PAUSED: if training is not complete and fw_name is
 *    "nntrainer", starts a dummy-data thread. The thread-function argument
 *    line is missing from this listing.
 *  - PAUSED->READY and READY->NULL only log.
 *
 * @return the parent class result. Returns GST_STATE_CHANGE_FAILURE when the
 *         PAUSED->PLAYING preparation fails; the parent is not called then.
 */
512 static GstStateChangeReturn
513 gst_tensor_trainer_change_state (GstElement * element,
514  GstStateChange transition)
515 {
516  GstTensorTrainer *trainer = GST_TENSOR_TRAINER (element);
517  GstStateChangeReturn ret = GST_STATE_CHANGE_SUCCESS;
518 
519  switch (transition) {
520  case GST_STATE_CHANGE_NULL_TO_READY:
521  GST_INFO_OBJECT (trainer, "NULL_TO_READY");
522  /* currently not used */
523  trainer->is_training_complete = FALSE;
524  break;
525 
526  case GST_STATE_CHANGE_READY_TO_PAUSED:
527  GST_INFO_OBJECT (trainer, "READY_TO_PAUSED");
528  break;
529 
530  case GST_STATE_CHANGE_PAUSED_TO_PLAYING:
531  GST_INFO_OBJECT (trainer, "PAUSED_TO_PLAYING");
533  goto state_change_failed;
534  if (!trainer->fw_created) {
535  if (!gst_tensor_trainer_create_model (trainer))
536  goto state_change_failed;
537  }
540  break;
541 
542  default:
543  break;
544  }
545 
546  ret = GST_ELEMENT_CLASS (parent_class)->change_state (element, transition);
547 
548  switch (transition) {
549  case GST_STATE_CHANGE_PLAYING_TO_PAUSED:
550  GST_INFO_OBJECT (trainer, "PLAYING_TO_PAUSED");
551  /* need to generate dummy data */
552  if (!trainer->is_training_complete) {
553  if (!g_strcmp0 (trainer->fw_name, "nntrainer")) {
554  GST_INFO_OBJECT (trainer, "cur_epoch_data_cnt=%u",
555  trainer->cur_epoch_data_cnt);
/* NOTE(review): if PLAYING->PAUSED happens more than once, the previous
 * dummy_data_thread handle is overwritten without g_thread_join (), so that
 * thread is never joined. The thread name string is also misspelled
 * ("dumy"); it is only cosmetic. */
556  trainer->dummy_data_thread =
557  g_thread_new ("dumy_data_generation_func",
559  trainer);
560  }
561  }
562  break;
563 
564  case GST_STATE_CHANGE_PAUSED_TO_READY:
565  GST_INFO_OBJECT (trainer, "PAUSED_TO_READY");
566  /* stop model train ? */
567  break;
568 
569  case GST_STATE_CHANGE_READY_TO_NULL:
570  GST_INFO_OBJECT (trainer, "READY_TO_NULL");
571  /* destroy or reset model ? */
572  break;
573 
574  default:
575  break;
576  }
577 
578  return ret;
579 
580 state_change_failed:
581  GST_ERROR_OBJECT (trainer, "state change failed");
582 
583  return GST_STATE_CHANGE_FAILURE;
584 }
585 
/**
 * @brief Block until is_epoch_complete is set (signalled on
 * epoch_completion_cond), then clear the flag for the next epoch. The name
 * line is missing from this listing.
 */
589 static void
591 {
592  g_return_if_fail (trainer != NULL);
593 
594  g_mutex_lock (&trainer->epoch_completion_lock);
/* Re-check the flag in a loop: g_cond_wait () may wake up spuriously. */
595  while (!trainer->is_epoch_complete) {
596  GST_INFO_OBJECT (trainer, "wait for epoch_completion_cond signal");
597  g_cond_wait (&trainer->epoch_completion_cond,
598  &trainer->epoch_completion_lock);
599  }
600  trainer->is_epoch_complete = FALSE;
601  g_mutex_unlock (&trainer->epoch_completion_lock);
602 }
603 
/**
 * @brief Check whether the current epoch has received all of its samples
 * (the name line is missing from this listing).
 *
 * Recomputes required_sample; its right-hand side is missing from this
 * listing. Returns FALSE while cur_epoch_data_cnt differs from it.
 * Otherwise it resets cur_epoch_data_cnt to 0 and returns TRUE (one
 * preceding line is also missing).
 *
 * NOTE(review): `&trainer->prop != NULL` is always true (it is the address of
 * a struct member), so that precondition check is a no-op.
 */
608 static gboolean
610 {
611  g_return_val_if_fail (trainer != NULL, FALSE);
612  g_return_val_if_fail (trainer->fw != NULL, FALSE);
613  g_return_val_if_fail (&trainer->prop != NULL, FALSE);
614 
615  trainer->required_sample =
617  if (trainer->cur_epoch_data_cnt != trainer->required_sample)
618  return FALSE;
619 
621  trainer->cur_epoch_data_cnt = 0;
622  return TRUE;
623 }
624 
/**
 * @brief Decide whether an incoming buffer must be dropped (the name line is
 * missing from this listing).
 * @return TRUE (after logging a warning) once is_training_complete is set,
 *         FALSE otherwise.
 */
628 static gboolean
630 {
631  if (trainer->is_training_complete == TRUE) {
634  GST_WARNING_OBJECT (trainer,
635  "Training is completed, buffer is dropped, please change state of pipeline");
636  return TRUE;
637  }
638  return FALSE;
639 }
640 
/**
 * @brief Per-buffer pre-checks run by the chain function (the name line is
 * missing from this listing).
 *
 * If no framework instance exists yet, it first evaluates a check whose
 * condition line is missing here. It then creates the model lazily. Buffers
 * with num_tensors >= NNS_TENSOR_SIZE_LIMIT are rejected.
 *
 * @return TRUE if the buffer can be processed.
 *
 * NOTE(review): the per-buffer arrays used while pushing input are sized
 * NNS_TENSOR_SIZE_LIMIT, so exactly NNS_TENSOR_SIZE_LIMIT tensors would still
 * fit; `>=` may be stricter than needed. Also, `return FALSE;;` contains a
 * stray semicolon.
 */
644 static gboolean
646  guint num_tensors)
647 {
648  if (!trainer->fw_created) {
650  return FALSE;;
651  if (!gst_tensor_trainer_create_model (trainer))
652  return FALSE;
653  }
654 
655  if (num_tensors >= NNS_TENSOR_SIZE_LIMIT)
656  return FALSE;
657 
658  return TRUE;
659 }
660 
/**
 * @brief Convert flexible-tensor meta into tensor info and return the header
 * size (the name line is missing from this listing).
 *
 * A validity check on a missing line logs "Invalid Flexible tensors" and
 * returns 0. Presumably that check parses @data into @meta; confirm. If
 * gst_tensor_meta_info_convert () succeeds, the header size of @meta is
 * returned; otherwise 0.
 *
 * NOTE(review): header_size is a gsize (unsigned), so the matching printf
 * conversion is %zu, not %zd.
 */
664 static gsize
666  GstTensorMetaInfo * meta, GstTensorInfo * info, void *data)
667 {
668  gsize header_size = 0;
669 
671  GST_ERROR_OBJECT (trainer, "Invalid Flexible tensors");
672  return 0;
673  }
674 
675  if (gst_tensor_meta_info_convert (meta, info)) {
676  header_size = gst_tensor_meta_info_get_header_size (meta);
677  GST_INFO ("flexible header size:%zd", header_size);
678  }
679 
680  return header_size;
681 }
682 
/**
 * @brief Push one sample (all memories of @inbuf) to the framework.
 *
 * Steps:
 *  - Maps every memory of @inbuf and points trainer->input_tensors[i] at
 *    its data. In flexible mode the per-memory header is skipped and the
 *    input meta is refreshed from it.
 *  - Checks each size against gst_tensor_trainer_get_tensor_size ().
 *  - Calls fw->push_data () and increments cur_epoch_data_cnt on success.
 *  - On every exit path, unmaps and unrefs the memories and resets
 *    input_tensors[0..n-1].
 *
 * in_meta[] is declared on a line missing from this listing.
 *
 * @return TRUE only if push_data () returned 0.
 */
686 static gboolean
687 gst_tensor_trainer_push_input (GstTensorTrainer * trainer, GstBuffer * inbuf,
688  gboolean in_flexible)
689 {
690  guint i, n;
691  GstMemory *in_mem[NNS_TENSOR_SIZE_LIMIT] = { 0, };
692  GstMapInfo in_info[NNS_TENSOR_SIZE_LIMIT];
694  GstTensorInfo *info;
695  gsize header_size = 0, expected;
696  gint ret = -1;
697 
698  n = gst_tensor_buffer_get_count (inbuf);
699 
/* Flexible streams carry their own tensor count; static streams must match
 * the negotiated count. */
700  if (in_flexible)
701  trainer->prop.input_meta.num_tensors = n;
702  else {
703  GST_DEBUG_OBJECT (trainer, "num_tensors: %u",
704  trainer->prop.input_meta.num_tensors);
705  if (n != trainer->prop.input_meta.num_tensors) {
706  GST_ERROR_OBJECT (trainer,
707  "Invalid memory blocks (%u), number of input tensors may be (%u)",
708  n, trainer->prop.input_meta.num_tensors);
709  goto error;
710  }
711  }
712 
713  for (i = 0; i < n; i++) {
714  in_mem[i] = gst_tensor_buffer_get_nth_memory (inbuf, i);
/* NOTE(review): if this map fails, in_mem[i] is already non-NULL, so the
 * cleanup loop below calls gst_memory_unmap () on memory that was never
 * mapped, with an unfilled in_info[i]. It should only unref it in that
 * case. */
715  if (!gst_memory_map (in_mem[i], &in_info[i], GST_MAP_READ)) {
716  GST_ERROR_OBJECT (trainer, "Could not map in_mem[%u] GstMemory", i);
717  goto error;
718  }
719 
720  if (in_flexible) {
721  info = gst_tensors_info_get_nth_info (&trainer->prop.input_meta, i);
722  header_size = gst_tensor_trainer_convert_meta (trainer,
723  &in_meta[i], info, in_info[i].data);
724  if (header_size == 0)
725  goto error;
726  }
727 
728  trainer->input_tensors[i].data = in_info[i].data + header_size;
729  trainer->input_tensors[i].size = in_info[i].size - header_size;
730  GST_INFO ("input_tensors[%u].size= %zd", i, trainer->input_tensors[i].size);
731  GST_INFO ("input_tensors[%u].data: %p", i, trainer->input_tensors[i].data);
732 
733  /* Check input tensor size */
734  expected = gst_tensor_trainer_get_tensor_size (trainer, i, TRUE);
735  if (expected != trainer->input_tensors[i].size) {
736  GST_ERROR_OBJECT (trainer,
737  "Invalid tensor size (%u'th memory chunk: %zd), expected size (%zd)",
738  i, trainer->input_tensors[i].size, expected);
739  goto error;
740  }
741  }
742 
743  ret = trainer->fw->push_data (trainer->fw, &trainer->prop,
744  trainer->privateData, trainer->input_tensors);
745 
746  if (ret < 0)
747  GST_ERROR_OBJECT (trainer, "push error");
748  else
749  trainer->cur_epoch_data_cnt++;
750 
/* Reached on success as well: the input data is only borrowed while
 * push_data () runs. */
751 error:
752  for (i = 0; i < n; i++) {
753  if (in_mem[i]) {
754  gst_memory_unmap (in_mem[i], &in_info[i]);
755  gst_memory_unref (in_mem[i]);
756  }
757 
758  trainer->input_tensors[i].data = NULL;
759  trainer->input_tensors[i].size = 0;
760  }
761 
762  return (ret == 0);
763 }
764 
/**
 * @brief Refresh training statistics from the sub-plugin via getStatus ()
 * and copy the positive values into @model_stats. The array is indexed by
 * TRAINING_LOSS, TRAINING_ACCURACY, VALIDATION_LOSS and VALIDATION_ACCURACY.
 * The name line is missing from this listing.
 *
 * @return FALSE if getStatus () fails, TRUE otherwise.
 */
768 static gboolean
770  double *model_stats)
771 {
772  gint ret = -1;
773 
774  ret =
775  trainer->fw->getStatus (trainer->fw, &trainer->prop,
776  trainer->privateData);
777  if (ret < 0) {
778  GST_ERROR_OBJECT (trainer, "Failed to Get status from sub-plugin.(%s).",
779  trainer->fw_name);
780  return FALSE;
781  }
782  /* If the value is invalid, it is already set by -INFINITY. */
/* NOTE(review): a statistic of exactly 0 (for example 0% accuracy early in
 * training) fails `> 0` and keeps the caller's initial value; confirm this
 * is intended. */
783  if (trainer->prop.training_loss > 0)
784  model_stats[TRAINING_LOSS] = trainer->prop.training_loss;
785  if (trainer->prop.training_accuracy > 0)
786  model_stats[TRAINING_ACCURACY] = trainer->prop.training_accuracy;
787  if (trainer->prop.validation_loss > 0)
788  model_stats[VALIDATION_LOSS] = trainer->prop.validation_loss;
789  if (trainer->prop.validation_accuracy > 0)
790  model_stats[VALIDATION_ACCURACY] = trainer->prop.validation_accuracy;
791 
792  GST_DEBUG_OBJECT (trainer,
793  "#%u/%u epochs [training_loss: %f, training_accuracy: %f, validation_loss: %f, validation_accuracy: %f]",
794  trainer->prop.epoch_count, trainer->prop.num_epochs,
795  model_stats[TRAINING_LOSS], model_stats[TRAINING_ACCURACY],
796  model_stats[VALIDATION_LOSS], model_stats[VALIDATION_ACCURACY]);
797 
798  return TRUE;
799 }
800 
/**
 * @brief Build the statistics output buffer: one memory per output tensor,
 * each filled with the current model statistics. Statistics that are not
 * available stay -INFINITY. The name line and the limit-check condition
 * are missing from this listing.
 *
 * @return a new buffer, or NULL on failure (the partial buffer is unref'd).
 *
 * NOTE(review):
 *  - memcpy () always copies sizeof (model_stats) bytes (MODEL_STATS_SIZE
 *    doubles), whatever data_size is. If an output tensor is smaller, or if
 *    get_tensor_size () returned 0 on error, this writes past the
 *    allocation. Copying MIN (data_size, sizeof (model_stats)) would be
 *    safe.
 *  - get_model_stats () is called once per output tensor; it could be
 *    hoisted out of the loop.
 */
804 static GstBuffer *
806 {
807  guint i;
808  size_t data_size;
809  double model_stats[MODEL_STATS_SIZE] =
810  { -INFINITY, -INFINITY, -INFINITY, -INFINITY };
811  GstBuffer *outbuf;
812  GstMemory *out_mem;
813  GstMapInfo out_info;
814  GstTensorInfo *info;
815  gboolean created = FALSE;
816 
818  GST_ERROR_OBJECT (trainer,
819  "The number of output tensors (%u) exceeds limit (%d)",
821  return NULL;
822  }
823 
824  outbuf = gst_buffer_new ();
825 
826  for (i = 0; i < trainer->output_meta.num_tensors; i++) {
827  if (!gst_tensor_trainer_get_model_stats (trainer, model_stats))
828  goto error;
829 
830  data_size = gst_tensor_trainer_get_tensor_size (trainer, i, FALSE);
831  info = gst_tensors_info_get_nth_info (&trainer->output_meta, i);
832 
833  out_mem = gst_allocator_alloc (NULL, data_size, NULL);
834  if (!out_mem) {
835  GST_ERROR_OBJECT (trainer, "Failed to allocate memory");
836  goto error;
837  }
838 
839  if (!gst_memory_map (out_mem, &out_info, GST_MAP_WRITE)) {
840  GST_ERROR_OBJECT (trainer, "Could not map out_mem[%u] GstMemory", i);
841  gst_memory_unref (out_mem);
842  goto error;
843  }
844 
845  memcpy (out_info.data, model_stats, sizeof (model_stats));
846  gst_memory_unmap (out_mem, &out_info);
847 
/* The buffer takes ownership of out_mem. */
848  gst_tensor_buffer_append_memory (outbuf, out_mem, info);
849  }
850 
851  created = TRUE;
852 
853 error:
854  if (created) {
855  GST_INFO ("out_buffer size : %zd", gst_buffer_get_size (outbuf));
856  } else {
857  gst_buffer_unref (outbuf);
858  outbuf = NULL;
859  }
860 
861  return outbuf;
862 }
863 
/**
 * @brief Sink pad chain function.
 *
 * Steps:
 *  - Validates the buffer with gst_tensor_trainer_check_chain_conditions ().
 *  - Returns GST_FLOW_OK without training when the drop check on a missing
 *    line triggers.
 *  - Pushes the sample to the framework.
 *  - When a condition involving cur_epoch_data_cnt holds, creates a
 *    statistics buffer and pushes it on the src pad. The second half of
 *    that condition is missing from this listing.
 *
 * @inbuf is always unref'd.
 *
 * @return GST_FLOW_OK on success/drop, the downstream push result when an
 *         output buffer is sent, GST_FLOW_ERROR otherwise.
 */
867 static GstFlowReturn
868 gst_tensor_trainer_chain (GstPad * sinkpad, GstObject * parent,
869  GstBuffer * inbuf)
870 {
871  GstTensorTrainer *trainer;
872  GstBuffer *outbuf = NULL;
873  GstFlowReturn ret = GST_FLOW_ERROR;
874  guint num_tensors;
875  gboolean in_flexible;
876 
877  trainer = GST_TENSOR_TRAINER (parent);
878  in_flexible = gst_tensor_pad_caps_is_flexible (sinkpad);
879  num_tensors = gst_tensor_buffer_get_count (inbuf);
880 
881  if (!gst_tensor_trainer_check_chain_conditions (trainer, num_tensors)) {
882  goto error;
883  }
884 
886  ret = GST_FLOW_OK;
887  goto error;
888  }
889 
890  if (!gst_tensor_trainer_push_input (trainer, inbuf, in_flexible)) {
891  goto error;
892  }
893 
899  if (trainer->cur_epoch_data_cnt == 1
901  outbuf = gst_tensor_trainer_create_output (trainer);
902 
903  if (outbuf)
904  ret = gst_pad_push (trainer->srcpad, outbuf);
905  } else {
906  /* Run flow, need more data? */
907  ret = GST_FLOW_OK;
908  }
909 
/* Common exit: the chain function owns inbuf and must release it. */
910 error:
911  gst_buffer_unref (inbuf);
912  return ret;
913 }
914 
/**
 * @brief Compute the caps @pad can handle from its tensors config.
 *
 * Uses in_config for the sink pad and out_config otherwise. When @filter
 * is given, the result is intersected with it, preferring the filter's
 * order (GST_CAPS_INTERSECT_FIRST). The name line is missing from this
 * listing.
 *
 * @return a caps reference the caller must unref, or NULL if a precondition
 *         fails.
 */
918 static GstCaps *
920  GstPad * pad, GstCaps * filter)
921 {
922  GstCaps *caps;
923  GstTensorsConfig *config;
924 
925  g_return_val_if_fail (trainer != NULL, NULL);
926  g_return_val_if_fail (pad != NULL, NULL);
927 
928  /* tensor config info for given pad */
929  if (pad == trainer->sinkpad) {
930  config = &trainer->in_config;
931  } else {
932  config = &trainer->out_config;
933  }
934 
935  caps = gst_tensor_pad_possible_caps_from_config (pad, config);
936  GST_DEBUG_OBJECT (trainer, "caps %" GST_PTR_FORMAT, caps);
937  GST_DEBUG_OBJECT (trainer, "filter %" GST_PTR_FORMAT, filter);
938 
939  if (caps && filter) {
940  GstCaps *result;
941  result = gst_caps_intersect_full (filter, caps, GST_CAPS_INTERSECT_FIRST);
942  gst_caps_unref (caps);
943  caps = result;
944  }
945 
946  GST_DEBUG_OBJECT (trainer, "result caps %" GST_PTR_FORMAT, caps);
947 
948  return caps;
949 }
950 
/**
 * @brief Block until is_training_complete is set and
 * training_completion_cond is signalled. The log message indicates this runs
 * after an EOS event. The name line is missing from this listing.
 */
954 static void
956 {
957  g_return_if_fail (trainer != NULL);
958 
959  g_mutex_lock (&trainer->training_completion_lock);
/* Loop guards against spurious wake-ups. */
960  while (!trainer->is_training_complete) {
961  GST_INFO_OBJECT (trainer,
962  "got GST_EVENT_EOS event but training is not completed, state is %d, "
963  "wait for training_completion_cond signal", GST_STATE (trainer));
964  g_cond_wait (&trainer->training_completion_cond,
965  &trainer->training_completion_lock);
966  }
967  g_mutex_unlock (&trainer->training_completion_lock);
968 
969  GST_DEBUG_OBJECT (trainer, "training is completed in sub-plugin[%s]",
970  trainer->fw_name);
971 }
972 
/**
 * @brief Sink pad event handler.
 *  - EOS: if training is not complete, a call on a line missing from this
 *    listing runs first (presumably the training-completion wait). The
 *    event is then forwarded by gst_pad_event_default ().
 *  - FLUSH_START / FLUSH_STOP: logged, then forwarded by the default
 *    handler.
 *  - CAPS: parses and validates the tensors config, copies it into
 *    prop.input_meta and in_config, derives out_config (framerate from
 *    input, info from output_meta) and sets the resulting caps on the src
 *    pad. The event is consumed here and not forwarded.
 *  - Any other event goes to gst_pad_event_default ().
 *
 * @return the result of gst_pad_set_caps () for CAPS, FALSE for invalid
 *         caps, otherwise the default handler's result.
 */
976 static gboolean
977 gst_tensor_trainer_sink_event (GstPad * sinkpad, GstObject * parent,
978  GstEvent * event)
979 {
980  GstTensorTrainer *trainer;
981  trainer = GST_TENSOR_TRAINER (parent);
982 
983  GST_DEBUG_OBJECT (trainer, "Received %s event: %" GST_PTR_FORMAT,
984  GST_EVENT_TYPE_NAME (event), event);
985 
986  switch (GST_EVENT_TYPE (event)) {
987  case GST_EVENT_EOS:
988  if (!trainer->is_training_complete)
990  break;
991  case GST_EVENT_FLUSH_START:
992  GST_INFO_OBJECT (trainer, "get GST_EVENT_FLUSH_START event");
993  break;
994  case GST_EVENT_FLUSH_STOP:
995  GST_INFO_OBJECT (trainer, "get GST_EVENT_FLUSH_STOP event");
996  break;
997  case GST_EVENT_CAPS:
998  {
999  GstCaps *in_caps;
1000  GstCaps *out_caps;
1001  GstStructure *structure;
1002  GstTensorsConfig config;
1003  gboolean ret = FALSE;
1004 
1005  gst_event_parse_caps (event, &in_caps);
1006  GST_INFO_OBJECT (trainer, "[in-caps] : %" GST_PTR_FORMAT, in_caps);
1007 
1008  structure = gst_caps_get_structure (in_caps, 0);
1009  if (!gst_tensors_config_from_structure (&config, structure) ||
1010  !gst_tensors_config_validate (&config)) {
1011  gst_tensors_config_free (&config);
1012  gst_event_unref (event);
1013  return FALSE;
1014  }
1015 
1016  /* copy TensorsInfo from negotiated caps to GstTensorTrainerProperties's input_meta */
1017  gst_tensors_info_copy (&trainer->prop.input_meta, &config.info);
1018 
1019  /* set tensor-config and out caps */
/* NOTE(review): this struct assignment overwrites any previous in_config
 * without freeing it first. That may leak on caps renegotiation, depending
 * on what GstTensorsConfig owns. out_caps below is also used without a NULL
 * check. */
1020  trainer->in_config = config;
1021  trainer->out_config.rate_n = config.rate_n;
1022  trainer->out_config.rate_d = config.rate_d;
1023  gst_tensors_info_copy (&trainer->out_config.info, &trainer->output_meta);
1024 
1025  out_caps =
1027  &trainer->out_config);
1028  GST_INFO_OBJECT (trainer, "[out-caps] : %" GST_PTR_FORMAT, out_caps);
1029 
1030  ret = gst_pad_set_caps (trainer->srcpad, out_caps);
1031 
1032  gst_event_unref (event);
1033  gst_caps_unref (out_caps);
1034  return ret;
1035  }
1036  default:
1037  break;
1038  }
1039  return gst_pad_event_default (sinkpad, parent, event);
1040 }
1041 
1045 static gboolean
1046 gst_tensor_trainer_sink_query (GstPad * sinkpad, GstObject * parent,
1047  GstQuery * query)
1048 {
1049  GstTensorTrainer *trainer;
1050  trainer = GST_TENSOR_TRAINER (parent);
1051 
1052  GST_DEBUG_OBJECT (trainer, "Received '%s' query: %" GST_PTR_FORMAT,
1053  GST_QUERY_TYPE_NAME (query), query);
1054 
1055  switch (GST_QUERY_TYPE (query)) {
1056  case GST_QUERY_CAPS:
1057  {
1058  GstCaps *caps;
1059  GstCaps *filter;
1060 
1061  GST_DEBUG_OBJECT (trainer, "[GST_QUERY_CAPS]");
1062  gst_query_parse_caps (query, &filter);
1063  GST_DEBUG_OBJECT (trainer, "Caps from query : %" GST_PTR_FORMAT, filter);
1064 
1065  caps = gst_tensor_trainer_query_caps (trainer, sinkpad, filter);
1066 
1067  GST_INFO_OBJECT (trainer, "[GST_QUERY_CAPS] : %" GST_PTR_FORMAT, caps);
1068  gst_query_set_caps_result (query, caps);
1069  gst_caps_unref (caps);
1070 
1071  return TRUE;
1072  }
1073  case GST_QUERY_ACCEPT_CAPS:
1074  {
1075  GstCaps *caps;
1076  GstCaps *template_caps;
1077  gboolean result = FALSE;
1078 
1079  GST_DEBUG_OBJECT (trainer, "[GST_QUERY_ACCEPT_CAPS]");
1080  gst_query_parse_accept_caps (query, &caps);
1081  GST_INFO_OBJECT (trainer, "Accept caps from query : %" GST_PTR_FORMAT,
1082  caps);
1083 
1084  if (gst_caps_is_fixed (caps)) {
1085  template_caps = gst_pad_get_pad_template_caps (sinkpad);
1086  GST_DEBUG_OBJECT (trainer, "sinkpad template_caps : %" GST_PTR_FORMAT,
1087  template_caps);
1088 
1089  result = gst_caps_can_intersect (template_caps, caps);
1090  gst_caps_unref (template_caps);
1091 
1092  GST_DEBUG_OBJECT (trainer, "intersect caps : %" GST_PTR_FORMAT, caps);
1093  }
1094 
1095  gst_query_set_accept_caps_result (query, result);
1096  return TRUE;
1097  }
1098  default:
1099  break;
1100  }
1101 
1102  return gst_pad_query_default (sinkpad, parent, query);
1103 }
1104 
1108 static gboolean
1109 gst_tensor_trainer_src_query (GstPad * srcpad, GstObject * parent,
1110  GstQuery * query)
1111 {
1112  GstTensorTrainer *trainer;
1113  trainer = GST_TENSOR_TRAINER (parent);
1114 
1115  GST_DEBUG_OBJECT (trainer, "Received %s query: %" GST_PTR_FORMAT,
1116  GST_QUERY_TYPE_NAME (query), query);
1117 
1118  switch (GST_QUERY_TYPE (query)) {
1119  case GST_QUERY_CAPS:
1120  {
1121  GstCaps *caps;
1122  GstCaps *filter;
1123  GST_DEBUG_OBJECT (trainer, "[GST_QUERY_CAPS]");
1124  gst_query_parse_caps (query, &filter);
1125  GST_DEBUG_OBJECT (trainer, "Caps from query : %" GST_PTR_FORMAT, filter);
1126  caps = gst_tensor_trainer_query_caps (trainer, srcpad, filter);
1127 
1128  GST_INFO_OBJECT (trainer, "[GST_QUERY_CAPS] : %" GST_PTR_FORMAT, caps);
1129  gst_query_set_caps_result (query, caps);
1130  gst_caps_unref (caps);
1131  return TRUE;
1132  }
1133  default:
1134  break;
1135  }
1136  return gst_pad_query_default (srcpad, parent, query);
1137 }
1138 
/**
 * @brief Setter for the "framework" property. Replaces trainer->fw_name
 * with a newly allocated copy of @value and logs it. The name line and one
 * trailing line are missing from this listing.
 */
1142 static void
1144  const GValue * value)
1145 {
/* Free the previous name before taking ownership of the duplicate. */
1146  g_free (trainer->fw_name);
1147  trainer->fw_name = g_value_dup_string (value);
1148  GST_INFO_OBJECT (trainer, "Framework: %s", trainer->fw_name);
1149 
1151 }
1152 
/**
 * @brief Setter for the model configuration file path. Replaces
 * trainer->prop.model_config with a newly allocated copy of @value and logs
 * it. The name line is missing from this listing.
 */
1156 static void
1158  trainer, const GValue * value)
1159 {
1160  g_free ((char *) trainer->prop.model_config);
1161  trainer->prop.model_config = g_value_dup_string (value);
1162  GST_INFO_OBJECT (trainer, "Model configuration file path: %s",
1163  trainer->prop.model_config);
1164 }
1165 
/**
 * @brief Setter for the model save path. Replaces
 * trainer->prop.model_save_path with a newly allocated copy of @value and
 * logs it. The name line is missing from this listing.
 */
1169 static void
1171  const GValue * value)
1172 {
1173  g_free ((char *) trainer->prop.model_save_path);
1174  trainer->prop.model_save_path = g_value_dup_string (value);
1175  GST_INFO_OBJECT (trainer, "File path to save the model: %s",
1176  trainer->prop.model_save_path);
1177 }
1178 
/**
 * @brief Setter for the optional model load path. Replaces
 * trainer->prop.model_load_path with a newly allocated copy of @value (which
 * may be NULL) and logs it. The name line is missing from this listing.
 */
1182 static void
1184  const GValue * value)
1185 {
1186  g_free ((char *) trainer->prop.model_load_path);
1187  trainer->prop.model_load_path = g_value_dup_string (value);
1188  GST_INFO_OBJECT (trainer, "File path to load the model: %s",
1189  trainer->prop.model_load_path);
1190 }
1191 
/**
 * @brief Look up the trainer sub-plugin called @name via get_subplugin ()
 * and store it in trainer->fw. The name line is missing from this listing.
 *
 * @return TRUE if found, FALSE otherwise.
 *
 * NOTE(review): both log messages print trainer->fw_name rather than @name.
 * The visible caller passes trainer->fw_name, so they agree today, but
 * logging @name would stay correct for other callers.
 */
1195 static gboolean
1197 {
1198  const GstTensorTrainerFramework *fw = NULL;
1199 
1200  g_return_val_if_fail (name != NULL, FALSE);
1201  g_return_val_if_fail (trainer != NULL, FALSE);
1202 
1203  GST_INFO_OBJECT (trainer, "Try to find framework: %s", name);
1204 
1205  fw = get_subplugin (NNS_SUBPLUGIN_TRAINER, name);
1206  if (!fw) {
1207  GST_ERROR_OBJECT (trainer, "Can not find framework(%s)", trainer->fw_name);
1208  return FALSE;
1209  }
1210 
1211  GST_INFO_OBJECT (trainer, "Find framework %s:%p", trainer->fw_name, fw);
1212  trainer->fw = fw;
1213 
1214  return TRUE;
1215 }
1216 
1220 static gboolean
1222 {
1223  g_return_val_if_fail (trainer != NULL, FALSE);
1224 
1225  if (!trainer->fw || trainer->fw_created) {
1226  GST_ERROR_OBJECT (trainer, "fw is not opened(%d) or fw is not null(%p)",
1227  trainer->fw_created, trainer->fw);
1228  return FALSE;
1229  }
1230 
1231  if (!trainer->fw->create) {
1232  GST_ERROR_OBJECT (trainer, "Could not create framework");
1233  return FALSE;
1234  }
1235 
1236  GST_DEBUG_OBJECT (trainer, "%p", trainer->privateData);
1237  if (trainer->fw->create (trainer->fw, &trainer->prop,
1238  &trainer->privateData) >= 0) {
1239  trainer->fw_created = TRUE;
1240  GST_DEBUG_OBJECT (trainer, "Success, Framework: %p", trainer->privateData);
1241  return TRUE;
1242  }
1243  return FALSE;
1244 }
1245 
1249 gsize
1251  guint index, gboolean is_input)
1252 {
1253  GstTensorsInfo *info;
1254 
1255  if (is_input)
1256  info = &trainer->prop.input_meta;
1257  else
1258  info = &trainer->output_meta;
1259 
1260  /* Internal Logic Error: out of bound */
1261  if (index >= info->num_tensors) {
1262  GST_ERROR_OBJECT (trainer, "has inconsistent data");
1263  return 0;
1264  }
1265 
1266  return gst_tensors_info_get_size (info, index);
1267 }
1268 
1272 static gboolean
1274 {
1275  gboolean ret = TRUE;
1276 
1277  g_return_val_if_fail (trainer != NULL, FALSE);
1278  g_return_val_if_fail (trainer->fw_name != NULL, FALSE);
1279 
1280  ret = gst_tensor_trainer_find_framework (trainer, trainer->fw_name);
1281  if (!ret)
1282  return ret;
1283 
1284  if (trainer->fw) {
1285  /* model create and compile */
1286  ret = gst_tensor_trainer_create_framework (trainer);
1287  }
1288 
1289  return ret;
1290 }
1291 
1295 static void
1297 {
1298  g_return_if_fail (trainer != NULL);
1299  g_return_if_fail (trainer->fw != NULL);
1300 
1301  trainer->notifier.notifier = (void *) trainer;
1302 }
1303 
1307 static void
1309 {
1310  gint ret = -1;
1311  g_return_if_fail (trainer != NULL);
1312  g_return_if_fail (trainer->fw != NULL);
1313  g_return_if_fail (trainer->fw->start != NULL);
1314 
1315  GST_DEBUG_OBJECT (trainer, "Start model training");
1316  ret =
1317  trainer->fw->start (trainer->fw, &trainer->prop, &trainer->notifier,
1318  trainer->privateData);
1319  if (ret != 0) {
1320  GST_ERROR_OBJECT (trainer, "Model training is failed");
1321  }
1322 }
1323 
1327 static void
1329 {
1330  gint ret = -1;
1331 
1332  g_return_if_fail (trainer != NULL);
1333  g_return_if_fail (trainer->fw != NULL);
1334  g_return_if_fail (trainer->fw->stop != NULL);
1335 
1336  GST_DEBUG_OBJECT (trainer, "Stop model training");
1337  ret = trainer->fw->stop (trainer->fw, &trainer->prop, &trainer->privateData);
1338  if (ret != 0) {
1339  GST_ERROR_OBJECT (trainer, "Stopping model training is failed");
1340  }
1341 }
1342 
1346 static void
1348 {
1349  GstTensorInfo *info;
1350 
1351  g_return_if_fail (trainer != NULL);
1352 
1353  gst_tensors_info_init (&trainer->output_meta);
1354  info = gst_tensors_info_get_nth_info (&trainer->output_meta, 0);
1355 
1356  info->type = _NNS_FLOAT64;
1357  info->dimension[0] = 1;
1358  info->dimension[1] = 1;
1359  info->dimension[2] = 4;
1360  info->dimension[3] = 1;
1361 
1362  trainer->output_meta.num_tensors = 1;
1363 }
1364 
1370 int
1372 {
1375  const char *name = NULL;
1376  int ret = 0;
1377 
1378  g_return_val_if_fail (ttsp != NULL, 0);
1379 
1380  memset (&prop, 0, sizeof (GstTensorTrainerProperties));
1381  gst_tensors_info_init (&prop.input_meta);
1382 
1383  if (ret != ttsp->getFrameworkInfo (ttsp, &prop, NULL, &info)) {
1384  GST_ERROR ("getFrameworkInfo() failed");
1385  return FALSE;
1386  }
1387  name = info.name;
1388 
1389  return register_subplugin (NNS_SUBPLUGIN_TRAINER, name, ttsp);
1390 }
1391 
1397 int
1399 {
1402  const char *name = NULL;
1403  int ret = 0;
1404 
1405  g_return_val_if_fail (ttsp != NULL, 0);
1406 
1407  memset (&prop, 0, sizeof (GstTensorTrainerProperties));
1408  gst_tensors_info_init (&prop.input_meta);
1409 
1410  if (ret != ttsp->getFrameworkInfo (ttsp, &prop, NULL, &info)) {
1411  GST_ERROR ("getFrameworkInfo() failed");
1412  return FALSE;
1413  }
1414  name = info.name;
1415 
1417 }
1418 
1424 void
1427 {
1428  GstTensorTrainer *trainer;
1429  g_return_if_fail (notifier != NULL);
1430  g_return_if_fail (type < TRAINER_EVENT_UNKNOWN || type > 0);
1431  UNUSED (data);
1432 
1433  trainer = (GstTensorTrainer *) notifier->notifier;
1434  g_return_if_fail (GST_IS_TENSOR_TRAINER (trainer));
1435 
1436  GST_DEBUG ("Received GstTensorTrainerEvent(%d)", type);
1437 
1438  switch (type) {
1440  g_mutex_lock (&trainer->epoch_completion_lock);
1441  trainer->is_epoch_complete = TRUE;
1442  GST_DEBUG ("send epoch_completion_cond signal");
1443  g_cond_signal (&trainer->epoch_completion_cond);
1444  g_mutex_unlock (&trainer->epoch_completion_lock);
1445  break;
1447  g_mutex_lock (&trainer->training_completion_lock);
1448  trainer->is_training_complete = TRUE;
1449  GST_DEBUG ("send training_completion_cond signal");
1450  g_cond_signal (&trainer->training_completion_cond);
1451  g_mutex_unlock (&trainer->training_completion_lock);
1452  break;
1453  default:
1454  break;
1455  }
1456 }
gst_tensor_trainer_create_framework
static gboolean gst_tensor_trainer_create_framework(GstTensorTrainer *trainer)
Create NN framework.
Definition: gsttensor_trainer.c:1221
gst_tensor_trainer_set_prop_framework
static void gst_tensor_trainer_set_prop_framework(GstTensorTrainer *trainer, const GValue *value)
Handle "PROP_FRAMEWORK" for set-property.
Definition: gsttensor_trainer.c:1143
TRAINER_EVENT_TRAINING_COMPLETION
@ TRAINER_EVENT_TRAINING_COMPLETION
Definition: nnstreamer_plugin_api_trainer.h:70
_GstTensorTrainer::notifier
GstTensorTrainerEventNotifier notifier
Definition: gsttensor_trainer.h:67
PROP_NUM_VALIDATION_SAMPLES
@ PROP_NUM_VALIDATION_SAMPLES
Definition: gsttensor_trainer.c:107
PROP_EPOCHS
@ PROP_EPOCHS
Definition: gsttensor_trainer.c:108
gst_tensor_trainer_stop_model_training
static void gst_tensor_trainer_stop_model_training(GstTensorTrainer *trainer)
Stop model training.
Definition: gsttensor_trainer.c:1328
_GstTensorTrainer::privateData
void * privateData
Definition: gsttensor_trainer.h:64
gst_tensor_trainer_class_init
static void gst_tensor_trainer_class_init(GstTensorTrainerClass *klass)
Initialize the tensor_trainer's class.
Definition: gsttensor_trainer.c:155
gst_tensor_trainer_init
static void gst_tensor_trainer_init(GstTensorTrainer *trainer)
Initialize tensor_trainer.
Definition: gsttensor_trainer.c:262
data
svtc_1 data
Definition: gsttensor_if.c:826
GstTensorInfo
Internal data structure for tensor info.
Definition: tensor_typedef.h:261
gst_tensor_trainer_start_model_training
static void gst_tensor_trainer_start_model_training(GstTensorTrainer *trainer)
Start model training.
Definition: gsttensor_trainer.c:1308
NNS_TENSOR_SIZE_LIMIT
#define NNS_TENSOR_SIZE_LIMIT
The number of tensors NNStreamer supports is 256. The max memories of gst-buffer is 16 (See NNS_TENSO...
Definition: tensor_typedef.h:42
gst_tensor_trainer_set_model_save_path
static void gst_tensor_trainer_set_model_save_path(GstTensorTrainer *trainer, const GValue *value)
Handle "PROP_MODEL_SAVE_PATH" for set-property.
Definition: gsttensor_trainer.c:1170
_GstTensorTrainer::is_training_complete
gboolean is_training_complete
Definition: gsttensor_trainer.h:53
src_template
static GstStaticPadTemplate src_template
The capabilities of the src pad.
Definition: gsttensor_trainer.c:58
FALSE
return FALSE
Definition: gsttensor_transform.c:596
result
case tensor_data_s gboolean * result
Definition: gsttensor_if.c:821
_GstTensorTrainerFramework::stop
int(* stop)(const GstTensorTrainerFramework *self, const GstTensorTrainerProperties *prop, void **private_data)
Definition: nnstreamer_plugin_api_trainer.h:128
nnstreamer_subplugin.h
Subplugin Manager for NNStreamer.
PROP_NUM_LABELS
@ PROP_NUM_LABELS
Definition: gsttensor_trainer.c:105
sink_template
static GstStaticPadTemplate sink_template
The capabilities of the sink pad.
Definition: gsttensor_trainer.c:50
PROP_MODEL_LOAD_PATH
@ PROP_MODEL_LOAD_PATH
Definition: gsttensor_trainer.c:103
PROP_0
@ PROP_0
Definition: gsttensor_trainer.c:99
GstTensorMemory::data
void * data
Definition: tensor_typedef.h:254
_GstTensorTrainer::dummy_data_thread
GThread * dummy_data_thread
Definition: gsttensor_trainer.h:74
GstTensorsInfo
Internal metadata exchange format for an other/tensors instance.
Definition: tensor_typedef.h:273
gst_tensor_trainer_sink_event
static gboolean gst_tensor_trainer_sink_event(GstPad *sinkpad, GstObject *parent, GstEvent *event)
Event handler for sink pad of tensor_trainer.
Definition: gsttensor_trainer.c:977
DEFAULT_PROP_LABEL_LIST
#define DEFAULT_PROP_LABEL_LIST
Definition: gsttensor_trainer.c:85
GST_DEBUG_CATEGORY_STATIC
GST_DEBUG_CATEGORY_STATIC(gst_tensor_trainer_debug)
DEFAULT_STR_PROP_VALUE
#define DEFAULT_STR_PROP_VALUE
Default string property value.
Definition: gsttensor_trainer.c:92
TRAINING_ACCURACY
@ TRAINING_ACCURACY
Definition: gsttensor_trainer.c:75
VALIDATION_LOSS
@ VALIDATION_LOSS
Definition: gsttensor_trainer.c:76
prop
GstTensorSrcIIOChannelProperties * prop
DTYPE_UNSIGNED ( .
Definition: gsttensor_srciio.c:110
_GstTensorTrainerProperties::training_loss
double training_loss
Definition: nnstreamer_plugin_api_trainer.h:44
gst_tensor_trainer_set_output_meta
static void gst_tensor_trainer_set_output_meta(GstTensorTrainer *trainer)
initialize the output tensor dimension
Definition: gsttensor_trainer.c:1347
_GstTensorTrainer::in_config
GstTensorsConfig in_config
Definition: gsttensor_trainer.h:59
_GstTensorTrainer::is_epoch_complete
gboolean is_epoch_complete
Definition: gsttensor_trainer.h:54
_GstTensorTrainerProperties::num_labels
unsigned int num_labels
Definition: nnstreamer_plugin_api_trainer.h:38
MODEL_STATS_SIZE
#define MODEL_STATS_SIZE
Definition: gsttensor_trainer.c:79
_GstTensorTrainerProperties::validation_loss
double validation_loss
Definition: nnstreamer_plugin_api_trainer.h:46
_GstTensorTrainer::training_completion_lock
GMutex training_completion_lock
Definition: gsttensor_trainer.h:69
gst_tensor_pad_caps_from_config
GstCaps * gst_tensor_pad_caps_from_config(GstPad *pad, const GstTensorsConfig *config)
Get pad caps from tensors config and caps of the peer connected to the pad.
Definition: nnstreamer_plugin_api_impl.c:1209
gst_tensors_info_init
void gst_tensors_info_init(GstTensorsInfo *info)
Initialize the tensors info structure.
Definition: nnstreamer_plugin_api_util_impl.c:325
_GstTensorTrainer::epoch_completion_lock
GMutex epoch_completion_lock
Definition: gsttensor_trainer.h:71
GstTensorTrainerEventType
GstTensorTrainerEventType
GstTensorTrainer's event type list.
Definition: nnstreamer_plugin_api_trainer.h:67
GstTensorMetaInfo
Data structure to describe a tensor data. This represents the basic information of a memory block for...
Definition: tensor_typedef.h:310
gst_tensor_trainer_get_tensor_size
static gsize gst_tensor_trainer_get_tensor_size(GstTensorTrainer *trainer, guint index, gboolean is_input)
Calculate tensor buffer size.
Definition: gsttensor_trainer.c:1250
GstTensorsConfig::rate_d
int rate_d
Definition: tensor_typedef.h:288
_GstTensorTrainerProperties::model_load_path
const char * model_load_path
Definition: nnstreamer_plugin_api_trainer.h:36
NNS_SUBPLUGIN_TRAINER
@ NNS_SUBPLUGIN_TRAINER
Definition: nnstreamer_subplugin.h:45
TRAINER_EVENT_EPOCH_COMPLETION
@ TRAINER_EVENT_EPOCH_COMPLETION
Definition: nnstreamer_plugin_api_trainer.h:69
_GstTensorTrainerProperties::training_accuracy
double training_accuracy
Definition: nnstreamer_plugin_api_trainer.h:45
gst_tensor_trainer_get_model_stats
static gboolean gst_tensor_trainer_get_model_stats(GstTensorTrainer *trainer, double *model_stats)
Get the model statistics from the sub-plugin.
Definition: gsttensor_trainer.c:769
gst_tensor_pad_caps_is_flexible
#define gst_tensor_pad_caps_is_flexible(p)
Macro to check current pad caps is flexible tensor.
Definition: tensor_common.h:231
_GstTensorTrainer
GstTensorTrainer data structure.
Definition: gsttensor_trainer.h:43
gst_tensor_trainer_change_state
static GstStateChangeReturn gst_tensor_trainer_change_state(GstElement *element, GstStateChange transition)
Change state of tensor_trainer.
Definition: gsttensor_trainer.c:513
DEFAULT_PROP_VALID_SAMPLES
#define DEFAULT_PROP_VALID_SAMPLES
Definition: gsttensor_trainer.c:87
g_free
g_free(self->option[(opnum) - 1])
opnum: \
_GstTensorTrainer::epoch_completion_cond
GCond epoch_completion_cond
Definition: gsttensor_trainer.h:72
gst_tensor_trainer_epochs_is_complete
static gboolean gst_tensor_trainer_epochs_is_complete(GstTensorTrainer *trainer)
Check if the current epoch is complete; tensor_trainer waits for each epoch to complete before getting...
Definition: gsttensor_trainer.c:609
g_value_set_string
g_value_set_string(value, self->option[opnum - 1])
opnum: \
gst_tensor_trainer_src_query
static gboolean gst_tensor_trainer_src_query(GstPad *srcpad, GstObject *parent, GstQuery *query)
This function handles src pad query.
Definition: gsttensor_trainer.c:1109
gst_tensor_trainer_wait_for_training_completion
static void gst_tensor_trainer_wait_for_training_completion(GstTensorTrainer *trainer)
Wait for training completion.
Definition: gsttensor_trainer.c:955
_GstTensorTrainerProperties::validation_accuracy
double validation_accuracy
Definition: nnstreamer_plugin_api_trainer.h:47
gst_tensor_meta_info_parse_header
gboolean gst_tensor_meta_info_parse_header(GstTensorMetaInfo *meta, gpointer header)
Parse header and fill the tensor meta.
Definition: nnstreamer_plugin_api_util_impl.c:1527
PROP_NUM_TRAINING_SAMPLES
@ PROP_NUM_TRAINING_SAMPLES
Definition: gsttensor_trainer.c:106
gst_tensor_trainer_dummy_data_generation_func
static gpointer gst_tensor_trainer_dummy_data_generation_func(GstTensorTrainer *trainer)
Dummy data generation thread.
Definition: gsttensor_trainer.c:471
GstTensorsConfig::rate_n
int rate_n
Definition: tensor_typedef.h:287
_GstTensorTrainer::out_config
GstTensorsConfig out_config
Definition: gsttensor_trainer.h:58
_GstTensorTrainerFramework::start
int(* start)(const GstTensorTrainerFramework *self, const GstTensorTrainerProperties *prop, GstTensorTrainerEventNotifier *notifier, void *private_data)
Definition: nnstreamer_plugin_api_trainer.h:118
_GstTensorTrainer::training_completion_cond
GCond training_completion_cond
Definition: gsttensor_trainer.h:70
PROP_FRAMEWORK
@ PROP_FRAMEWORK
Definition: gsttensor_trainer.c:100
_GstTensorTrainerProperties::num_inputs
unsigned int num_inputs
Definition: nnstreamer_plugin_api_trainer.h:37
GstTensorMemory::size
size_t size
Definition: tensor_typedef.h:255
gst_tensor_trainer_sink_query
static gboolean gst_tensor_trainer_sink_query(GstPad *sinkpad, GstObject *parent, GstQuery *query)
This function handles sink pad query.
Definition: gsttensor_trainer.c:1046
_GstTensorTrainer::required_sample
guint required_sample
Definition: gsttensor_trainer.h:61
gst_tensors_config_free
void gst_tensors_config_free(GstTensorsConfig *config)
Free allocated data in tensors config structure.
Definition: nnstreamer_plugin_api_util_impl.c:845
register_subplugin
gboolean register_subplugin(subpluginType type, const char *name, const void *data)
Public function defined in the header.
Definition: nnstreamer_subplugin.c:225
_GstTensorTrainerFramework
tensor_trainer subplugin definition
Definition: nnstreamer_plugin_api_trainer.h:95
nnstreamer_trainer_notify_event
void nnstreamer_trainer_notify_event(GstTensorTrainerEventNotifier *notifier, GstTensorTrainerEventType type, void *data)
Trainer's sub-plugin may call this to send event.
Definition: gsttensor_trainer.c:1425
_GstTensorTrainer::output_meta
GstTensorsInfo output_meta
Definition: gsttensor_trainer.h:57
_GstTensorTrainerProperties::num_validation_samples
unsigned int num_validation_samples
Definition: nnstreamer_plugin_api_trainer.h:40
DEFAULT_PROP_EPOCHS
#define DEFAULT_PROP_EPOCHS
Definition: gsttensor_trainer.c:88
_GstTensorTrainer::input_tensors
GstTensorMemory input_tensors[NNS_TENSOR_SIZE_LIMIT]
Definition: gsttensor_trainer.h:56
gst_tensor_trainer_set_property
static void gst_tensor_trainer_set_property(GObject *object, guint prop_id, const GValue *value, GParamSpec *pspec)
Setter for tensor_trainer properties.
Definition: gsttensor_trainer.c:350
_GstTensorTrainerProperties::num_training_samples
unsigned int num_training_samples
Definition: nnstreamer_plugin_api_trainer.h:39
gst_tensor_trainer_set_model_load_path
static void gst_tensor_trainer_set_model_load_path(GstTensorTrainer *trainer, const GValue *value)
Handle "PROP_MODEL_LOAD_PATH" for set-property.
Definition: gsttensor_trainer.c:1183
G_DEFINE_TYPE
G_DEFINE_TYPE(GstTensorTrainer, gst_tensor_trainer, GST_TYPE_ELEMENT)
gst_tensor_meta_info_convert
gboolean gst_tensor_meta_info_convert(GstTensorMetaInfo *meta, GstTensorInfo *info)
Convert GstTensorMetaInfo structure to GstTensorInfo.
Definition: nnstreamer_plugin_api_util_impl.c:1562
GstTensorsConfig
Internal data structure for configured tensors info (for other/tensors).
Definition: tensor_typedef.h:284
_GstTensorTrainerProperties::input_meta
GstTensorsInfo input_meta
Definition: nnstreamer_plugin_api_trainer.h:33
_GstTensorTrainer::fw_name
gchar * fw_name
Definition: gsttensor_trainer.h:50
VALIDATION_ACCURACY
@ VALIDATION_ACCURACY
Definition: gsttensor_trainer.c:77
_GstTensorTrainerFramework::getStatus
int(* getStatus)(const GstTensorTrainerFramework *self, GstTensorTrainerProperties *prop, void *private_data)
Definition: nnstreamer_plugin_api_trainer.h:146
GST_TENSOR_TRAINER
#define GST_TENSOR_TRAINER(obj)
Definition: gsttensor_trainer.h:28
_GstTensorTrainer::sinkpad
GstPad * sinkpad
Definition: gsttensor_trainer.h:47
TRUE
return TRUE
Definition: gsttensor_if.c:879
UNUSED
#define UNUSED(expr)
Definition: mqttcommon.h:19
nnstreamer_util.h
Optional NNStreamer utility functions for sub-plugin writers and users.
_GstTensorTrainer::fw_created
gboolean fw_created
Definition: gsttensor_trainer.h:52
gst_tensor_pad_possible_caps_from_config
GstCaps * gst_tensor_pad_possible_caps_from_config(GstPad *pad, const GstTensorsConfig *config)
Get all possible caps from tensors config. Unlike gst_tensor_pad_caps_from_config(),...
Definition: nnstreamer_plugin_api_impl.c:1286
gst_tensor_trainer_create_event_notifier
static void gst_tensor_trainer_create_event_notifier(GstTensorTrainer *trainer)
Create an event notifier.
Definition: gsttensor_trainer.c:1296
_GstTensorTrainerClass
GstTensorTrainerClass data structure.
Definition: gsttensor_trainer.h:80
nnstreamer_trainer_probe
int nnstreamer_trainer_probe(GstTensorTrainerFramework *ttsp)
Trainer's sub-plugin should call this function to register itself.
Definition: gsttensor_trainer.c:1371
DEFAULT_PROP_INPUT_LIST
#define DEFAULT_PROP_INPUT_LIST
Default framework property value.
Definition: gsttensor_trainer.c:84
_GstTensorTrainerFrameworkInfo
GstTensorTrainer's subplugin framework related information.
Definition: nnstreamer_plugin_api_trainer.h:55
_GstTensorTrainerProperties::model_save_path
const char * model_save_path
Definition: nnstreamer_plugin_api_trainer.h:35
PROP_MODEL_SAVE_PATH
@ PROP_MODEL_SAVE_PATH
Definition: gsttensor_trainer.c:102
gst_tensors_info_get_nth_info
GstTensorInfo * gst_tensors_info_get_nth_info(GstTensorsInfo *info, guint index)
Get the pointer of nth tensor information.
Definition: nnstreamer_plugin_api_util_impl.c:296
_GstTensorTrainerProperties::num_epochs
unsigned int num_epochs
Definition: nnstreamer_plugin_api_trainer.h:41
_GstTensorTrainerProperties::model_config
const char * model_config
Definition: nnstreamer_plugin_api_trainer.h:34
gst_tensors_info_get_size
gsize gst_tensors_info_get_size(const GstTensorsInfo *info, gint index)
Get data size of single tensor.
Definition: nnstreamer_plugin_api_util_impl.c:376
_GstTensorTrainer::cur_epoch_data_cnt
guint cur_epoch_data_cnt
Definition: gsttensor_trainer.h:62
gst_tensor_trainer_create_output
static GstBuffer * gst_tensor_trainer_create_output(GstTensorTrainer *trainer)
Create output tensors.
Definition: gsttensor_trainer.c:805
_GstTensorTrainerFramework::create
int(* create)(const GstTensorTrainerFramework *self, const GstTensorTrainerProperties *prop, void **private_data)
Definition: nnstreamer_plugin_api_trainer.h:102
nnstreamer_trainer_exit
int nnstreamer_trainer_exit(GstTensorTrainerFramework *ttsp)
Trainer's sub-plugin may call this to unregister itself.
Definition: gsttensor_trainer.c:1398
_GstTensorTrainer::srcpad
GstPad * srcpad
Definition: gsttensor_trainer.h:48
gst_tensor_trainer_get_property
static void gst_tensor_trainer_get_property(GObject *object, guint prop_id, GValue *value, GParamSpec *pspec)
Getter for tensor_trainer properties.
Definition: gsttensor_trainer.c:395
_GstTensorTrainerFramework::getFrameworkInfo
int(* getFrameworkInfo)(const GstTensorTrainerFramework *self, const GstTensorTrainerProperties *prop, void *private_data, GstTensorTrainerFrameworkInfo *fw_info)
Definition: nnstreamer_plugin_api_trainer.h:154
_GstTensorTrainerFramework::destroy
int(* destroy)(const GstTensorTrainerFramework *self, const GstTensorTrainerProperties *prop, void **private_data)
Definition: nnstreamer_plugin_api_trainer.h:110
_GstTensorTrainerProperties::epoch_count
unsigned int epoch_count
Definition: nnstreamer_plugin_api_trainer.h:43
GST_CAT_DEFAULT
#define GST_CAT_DEFAULT
Definition: gsttensor_trainer.c:64
GstTensorsInfo::num_tensors
unsigned int num_tensors
Definition: tensor_typedef.h:275
gst_tensor_trainer_create_model
static gboolean gst_tensor_trainer_create_model(GstTensorTrainer *trainer)
Create model.
Definition: gsttensor_trainer.c:1273
_NNS_FLOAT64
@ _NNS_FLOAT64
Definition: tensor_typedef.h:146
gst_tensor_trainer_wait_for_epoch_completion
static void gst_tensor_trainer_wait_for_epoch_completion(GstTensorTrainer *trainer)
Wait for epoch completion.
Definition: gsttensor_trainer.c:590
gst_tensor_buffer_get_nth_memory
GstMemory * gst_tensor_buffer_get_nth_memory(GstBuffer *buffer, const guint index)
Get the nth GstMemory from given buffer.
Definition: nnstreamer_plugin_api_impl.c:1608
gst_tensor_trainer_check_chain_conditions
static gboolean gst_tensor_trainer_check_chain_conditions(GstTensorTrainer *trainer, guint num_tensors)
Check chain conditions. If all conditions are met, proceed to next step.
Definition: gsttensor_trainer.c:645
PROP_MODEL_CONFIG
@ PROP_MODEL_CONFIG
Definition: gsttensor_trainer.c:101
DEFAULT_PROP_TRAIN_SAMPLES
#define DEFAULT_PROP_TRAIN_SAMPLES
Definition: gsttensor_trainer.c:86
gst_tensor_trainer_set_prop_model_config_file_path
static void gst_tensor_trainer_set_prop_model_config_file_path(GstTensorTrainer *trainer, const GValue *value)
Handle "PROP_MODEL_CONFIG" for set-property.
Definition: gsttensor_trainer.c:1157
gst_tensor_trainer_check_invalid_param
static gboolean gst_tensor_trainer_check_invalid_param(GstTensorTrainer *trainer)
Check invalid param.
Definition: gsttensor_trainer.c:440
gst_tensors_config_init
void gst_tensors_config_init(GstTensorsConfig *config)
Initialize the tensors config info structure (for other/tensors)
Definition: nnstreamer_plugin_api_util_impl.c:830
gst_tensor_buffer_get_count
guint gst_tensor_buffer_get_count(GstBuffer *buffer)
Get the number of tensors in the buffer.
Definition: nnstreamer_plugin_api_impl.c:1835
_GstTensorTrainerEventNotifier::notifier
void * notifier
Definition: nnstreamer_plugin_api_trainer.h:82
TRAINING_LOSS
@ TRAINING_LOSS
Definition: gsttensor_trainer.c:74
gsttensor_trainer.h
GStreamer plugin to train tensor data using NN Frameworks.
GstTensorInfo::type
tensor_type type
Definition: tensor_typedef.h:266
GstTensorsConfig::info
GstTensorsInfo info
Definition: tensor_typedef.h:286
gst_tensors_config_validate
gboolean gst_tensors_config_validate(const GstTensorsConfig *config)
Check the tensors are all configured (for other/tensors)
Definition: nnstreamer_plugin_api_util_impl.c:858
gst_tensor_trainer_push_input
static gboolean gst_tensor_trainer_push_input(GstTensorTrainer *trainer, GstBuffer *inbuf, gboolean in_flexible)
Create input tensors from the buffer and push them into the trainer framework.
Definition: gsttensor_trainer.c:687
PROP_NUM_INPUTS
@ PROP_NUM_INPUTS
Definition: gsttensor_trainer.c:104
GST_IS_TENSOR_TRAINER
#define GST_IS_TENSOR_TRAINER(obj)
Definition: gsttensor_trainer.h:32
gst_tensors_info_copy
void gst_tensors_info_copy(GstTensorsInfo *dest, const GstTensorsInfo *src)
Copy tensor info.
Definition: nnstreamer_plugin_api_util_impl.c:502
GstTensorInfo::dimension
tensor_dim dimension
Definition: tensor_typedef.h:267
_GstTensorTrainerFramework::push_data
int(* push_data)(const GstTensorTrainerFramework *self, const GstTensorTrainerProperties *prop, void *private_data, const GstTensorMemory *input)
Definition: nnstreamer_plugin_api_trainer.h:136
gst_tensor_trainer_chain
static GstFlowReturn gst_tensor_trainer_chain(GstPad *sinkpad, GstObject *parent, GstBuffer *inbuf)
Chain function, this function does the actual processing.
Definition: gsttensor_trainer.c:868
gst_tensors_config_from_structure
gboolean gst_tensors_config_from_structure(GstTensorsConfig *config, const GstStructure *structure)
Parse structure and set tensors config (for other/tensors)
Definition: nnstreamer_plugin_api_impl.c:1413
type
svtc_1 type
Definition: gsttensor_if.c:825
_GstTensorTrainerEventNotifier
GstTensorTrainer's event notifier.
Definition: nnstreamer_plugin_api_trainer.h:79
gst_tensor_buffer_append_memory
gboolean gst_tensor_buffer_append_memory(GstBuffer *buffer, GstMemory *memory, const GstTensorInfo *info)
Append memory to given buffer.
Definition: nnstreamer_plugin_api_impl.c:1688
_GstTensorTrainer::prop
GstTensorTrainerProperties prop
Definition: gsttensor_trainer.h:66
_GstTensorTrainerFrameworkInfo::name
const char * name
Definition: nnstreamer_plugin_api_trainer.h:57
gst_tensor_trainer_query_caps
static GstCaps * gst_tensor_trainer_query_caps(GstTensorTrainer *trainer, GstPad *pad, GstCaps *filter)
Get pad caps for caps negotiation.
Definition: gsttensor_trainer.c:919
gst_tensor_trainer_convert_meta
static gsize gst_tensor_trainer_convert_meta(GstTensorTrainer *trainer, GstTensorMetaInfo *meta, GstTensorInfo *info, void *data)
Convert tensor meta and get the size of tensor header.
Definition: gsttensor_trainer.c:665
_GstTensorTrainer::fw
const GstTensorTrainerFramework * fw
Definition: gsttensor_trainer.h:65
gst_tensor_trainer_finalize
static void gst_tensor_trainer_finalize(GObject *object)
Function to finalize instance.
Definition: gsttensor_trainer.c:316
GST_ERROR
GST_ERROR("Failed to register nnstreamer plugin : tensor_" # name)
type)) { \
_GstTensorTrainerProperties
GstTensorTrainer's properties for neural network framework (internal data structure)
Definition: nnstreamer_plugin_api_trainer.h:31
SINK_CAPS_STRING
#define SINK_CAPS_STRING
Default caps string for sink.
Definition: gsttensor_trainer.c:40
get_subplugin
const void * get_subplugin(subpluginType type, const char *name)
Public function defined in the header.
Definition: nnstreamer_subplugin.c:141
gst_tensor_trainer_check_buffer_drop_conditions
static gboolean gst_tensor_trainer_check_buffer_drop_conditions(GstTensorTrainer *trainer)
Check buffer drop conditions. If condition is met, drop the buffer.
Definition: gsttensor_trainer.c:629
unregister_subplugin
gboolean unregister_subplugin(subpluginType type, const char *name)
Public function defined in the header.
Definition: nnstreamer_subplugin.c:289
gst_tensor_meta_info_get_header_size
gsize gst_tensor_meta_info_get_header_size(GstTensorMetaInfo *meta)
Get the header size to handle a tensor meta.
Definition: nnstreamer_plugin_api_util_impl.c:1456
SRC_CAPS_STRING
#define SRC_CAPS_STRING
Default caps string for src.
Definition: gsttensor_trainer.c:45
gst_tensor_trainer_find_framework
static gboolean gst_tensor_trainer_find_framework(GstTensorTrainer *trainer, const char *name)
Find the trainer sub-plugin with the given name.
Definition: gsttensor_trainer.c:1196