added a GUI mode for stepping thru frames to help evaluate VQ
authordrowe67 <drowe67@01035d8c-6547-0410-b346-abe4f91aad63>
Thu, 6 Jul 2017 09:45:55 +0000 (09:45 +0000)
committerdrowe67 <drowe67@01035d8c-6547-0410-b346-abe4f91aad63>
Thu, 6 Jul 2017 09:45:55 +0000 (09:45 +0000)
git-svn-id: https://svn.code.sf.net/p/freetel/code@3279 01035d8c-6547-0410-b346-abe4f91aad63

codec2-dev/octave/kmeans_tests.m

index 24159ca9a3f23ca4f53ce435fe8aeefb6ae3ca4e..9b4d8837a33bb16c5f536c4bf324b8606c025a4c 100644 (file)
@@ -17,7 +17,7 @@
 %----------------------------------------------------------------------
 % standard mean squared error search
 
-function [idx contrib errors test_ g mg] = vq_search_mse(vq, data)
+function [idx contrib errors test_ g mg sl] = vq_search_mse(vq, data)
   [nVec nCols] = size(vq);
   nRows = rows(data);
   
@@ -38,14 +38,14 @@ function [idx contrib errors test_ g mg] = vq_search_mse(vq, data)
     test_(f,:) = vq(min_ind,:);
   end
 
-  g = mg = 1; % dummys for this function
+  g = 0; mg = 1; sl = 0; % dummys for this function
 endfunction
 
 
 %----------------------------------------------------------------------
 % abs() search with a linear gain term
 
-function [idx contrib errors test_ g mg] = vq_search_gain(vq, data)
+function [idx contrib errors test_ g mg sl] = vq_search_gain(vq, data)
   
   [nVec nCols] = size(vq);
   nRows = rows(data);
@@ -79,14 +79,14 @@ function [idx contrib errors test_ g mg] = vq_search_gain(vq, data)
     errors(f) = mn; 
     contrib(f,:) = test_(f,:) = vq(min_ind,:) + g(f,min_ind);
   end
-  mg = 1;
+  mg = 1; sl = 0;
 endfunction
 
 
 %----------------------------------------------------------------------
 % abs() search with a linear plus ampl scaling term
 
-function [idx contrib errors test_ g mg] = vq_search_mag(vq, data)
+function [idx contrib errors test_ g mg sl] = vq_search_mag(vq, data)
   [nVec nCols] = size(vq);
   nRows = rows(data);
   
@@ -128,7 +128,7 @@ function [idx contrib errors test_ g mg] = vq_search_mag(vq, data)
 
     contrib(f,:) = test_(f,:) = mg(f,min_ind) * vq(min_ind,:) + g(f,min_ind);
   end
-
+  sl = 0;
 endfunction
 
 
@@ -202,8 +202,8 @@ function sd_per_frame = run_test(vq, test, nVec, search_func, gui_en = 0)
   
   [nRows nCols] = size(test);
 
-  [idx contrib errors test_ g mg] = feval(search_func, vq, test);
-  
+  [idx contrib errors test_ g mg sl] = feval(search_func, vq, test);
+
   % sd over time
  
   sd_per_frame = zeros(nRows,1);
@@ -211,44 +211,111 @@ function sd_per_frame = run_test(vq, test, nVec, search_func, gui_en = 0)
     sd_per_frame(i) = mean(abs(test(i,:) - test_(i,:)));
   end
 
+  printf("average sd: %3.2f\n", mean(sd_per_frame));
   %printf("%18s nVec: %d SD: %3.2f dB\n", search_func, nVec, mean(sd_per_frame));
   
-  % plots sd and errors over time
+  % Optional GUI with U to anlayse VQ perf ---------------------------------------------------
 
   if gui_en
-    figure(1); clf; subplot(211); plot(sd_per_frame); title('SD'); subplot(212); plot(errors); title('mean error');
+    gui_mode = "sort";     % "sort" or "fbf"
+    sort_order = "worst";  % "worst" or "best"
+    f = 1;
+
+    % plots sd over time and histogram
+
+    figure(1); clf; subplot(211); plot(sd_per_frame); title('SD'); 
+    subplot(212); [yy xx] = hist(sd_per_frame); plot(xx,yy,'+-');
+  end
+
+  while gui_en
 
     % display m frames, printing some stats, plotting vector to give visual idea of match
 
     figure(2); clf;
-    [errors_dec frame_dec] = sort(errors, "descend");
-    m = 4;
+
+    % try this in 'ascend' or 'descend' mode to best or worst frames
+
+    if strcmp(gui_mode, "sort")
+      if strcmp(sort_order, "best")
+        [errors_dec frame_dec] = sort(errors, "ascend");
+      end
+      if strcmp(sort_order, "worst")
+        [errors_dec frame_dec] = sort(errors, "descend");
+      end
+       m = 4;
+    else
+      m = 1;
+      errors_dec = errors(f); frame_dec(1) = f;
+    end
+
     for i=1:m
+
+      % build up VQ legend
+
       af = frame_dec(i); aind = idx(af);
-      l = sprintf("idx: %d", aind);
-      ag = 0; amg = 1;
-      if strcmp(search_func, "vq_search_gain") || strcmp(search_func, "vq_search_mag")
+
+      ag = 0; amg = 1; asl = 0;
+      if strcmp(search_func, "vq_search_gain")
         ag = g(af,aind);
-        l = sprintf("%s g: %3.2f", l, ag);
+        l3 = sprintf("g: %3.2f", ag);
       end
       if strcmp(search_func, "vq_search_mag")
+        ag = g(af,aind);
+        amg = mg(af,aind);
+        l3  = sprintf("g: %3.2f mg: %3.2f", ag, amg);
+      end
+      if strcmp(search_func, "vq_search_slope")
+        ag = g(af,aind);
         amg = mg(af,aind);
-        l = sprintf("%s mg: %3.2f", l, amg);
+        asl = sl(af,aind);
+        l3 = sprintf("g: %3.2f mg: %3.2f sl: %3.2f", ag, amg, asl);
       end
-      %printf("%d f: %d %s\n", i, af, l);
 
-      subplot(sqrt(m),sqrt(m),i);
-      l1 = sprintf("b-;fr %d;", af);
+      % plot target 
+
+      subplot(sqrt(m), sqrt(m),i);
+      l1 = sprintf("b-;fr %d sd: %3.2f;", af, errors_dec(m));
       plot(test(af,:), l1); 
 
+      % plot vq vector and modified version
+
       hold on; 
       l2 = sprintf("g-+;ind %d;", aind); 
       plot(vq(aind, :), l2); 
-      l3 = sprintf("g-o;%s;",l);   
-      plot(amg*vq(aind, :) + ag, l3); 
+      l3 = sprintf("g-o;%s;",l3);
+      plot(amg*vq(aind, :) + ag + asl*(1:nCols), l3); 
       hold off;
       axis([1 nCols -10 40]);
     end
+
+    % interactive menu ------------------------------------------
+
+    if strcmp(gui_mode, "sort")
+      printf("\rmenu: m-mode[%s] o-order[%s] q-quit", gui_mode, sort_order);
+    end
+    if strcmp(gui_mode, "fbf")
+      printf("\rmenu: m-mode[%s] frame: %d  n-next  b-back  q-quit", gui_mode, f);
+    end
+    fflush(stdout);
+    k = kbhit();
+    
+    if k == 'm'
+      if strcmp(gui_mode, "sort")
+        gui_mode = "fbf"; f = frame_dec(1); 
+      else
+        gui_mode = "sort";
+      end;
+    end
+    if k == 'o'
+       if strcmp(sort_order, "worst")
+         sort_order = "best"; 
+       else
+         sort_order = "worst";
+      end;
+    end
+    if k == 'n', f = f + 1; endif;
+    if k == 'b', f = f - 1; endif;
+    if k == 'q', gui_en = 0; printf("\n"); endif;
   end  
   
 endfunction
@@ -277,10 +344,13 @@ function [sd des] = train_vq_and_run_tests(sim_in)
   Nvec = sim_in.Nvec;
 
   train_func = get_search_func(sim_in.train_func_short);
-  [idx vq] = kmeans(sim_in.trainvec, Nvec, 
-                    "start", "sample", 
-                    "emptyaction", "singleton",
-                    "search_func", train_func);
+  printf("    Nvec %d kmeans start\n", Nvec);
+  [idx vq] = feval(sim_in.kmeans,
+                   sim_in.trainvec, Nvec, 
+                   "start", "sample", 
+                   "emptyaction", "singleton",
+                   "search_func", train_func);
+  printf("    Nvec %d kmeans end\n", Nvec);
 
   sd = []; des = []; tests = sim_in.tests;
   for t=1:length(cellstr(tests))
@@ -323,7 +393,7 @@ function plot_sd_results(fg, title_str, sd, desc, train_list, test_list)
     for j=1:length(cellstr(test_list))
       test_func_short = char(cellstr(test_list)(j));
       [y x] = search_tests(sd, desc, train_func_short, test_func_short);
-      leg = sprintf("o-%d;train: %5s test: %5s;", nlines, train_func_short, test_func_short)
+      leg = sprintf("o-%d;train: %5s test: %5s;", nlines, train_func_short, test_func_short);
       if nlines, hold on; end;
       x += inc; inc += 0.1;            % separate x coords a bit to make errors bars legible
       errorbar(x, y(:,1), y(:,2), leg);
@@ -372,8 +442,15 @@ end
 
 %----------------------------------------------------------------------
 % Run a bunch of long tests in parallel and plot results
+%
+
+% This test tries a bunch of training and test search combinations on
+% the first 1kHz to see if there is any advantage in using the higher
+% order search techniques in the VQ training.  So far it appears
+% not....  However the higher order search routines do give great
+% results compared to a basic MSE (or oder) search.
 
-function [sd desc] = long_tests(quick_check=0)
+function [sd desc] = long_tests_1_10(quick_check=0)
   num_cores = 4;
   K = 10;
   load surf_train_120_hpf150; load surf_all_hpf150;
@@ -391,7 +468,8 @@ function [sd desc] = long_tests(quick_check=0)
 
   % build up a big array of tests to run ------------------------
 
-  sim_in.trainvec = trainvec; sim_in.testvec = testvec; 
+  sim_in.trainvec = trainvec; sim_in.testvec = testvec;
+  sim_in.kmeans = "kmeans2";  % slow but supports different search functions
   sim_in.Nvec = 64; 
   sim_in.train_func_short = "mse"; 
   sim_in.tests = ["mse"; "gain"; "mag"; "slope"];
@@ -421,7 +499,8 @@ function [sd desc] = long_tests(quick_check=0)
   % run test list in parallel
 
   [sd desc] =  pararrayfun(num_cores, @train_vq_and_run_tests, test_list);
-
+  %train_vq_and_run_tests(test_list(1));
+  
   % Plot results -----------------------------------------------
 
   fg = 1;
@@ -446,27 +525,98 @@ function [sd desc] = long_tests(quick_check=0)
 endfunction
 
 
-function short_detailed_test(train_func, test_func)
-  K = 10;
+%----------------------------------------------------------------------
+% This test tries some test combinations on the 1000 to 3000 Hz range.
+%
+% Based on our results with long_test_1_10() we just use MSE training
+
+function [sd desc] = long_tests_11_40(quick_check=0)
+  num_cores = 4;
+  K = 20; st = 11; en = 40;
+  load surf_train_120_hpf150; load surf_all_hpf150;
+
+  if quick_check
+    NtrainVec = 1000;
+    NtestVec = 100;
+  else
+    NtrainVec = length(surf_train_120_hpf150);
+    NtestVec = length(surf_all_hpf150);
+  end
+
+  trainvec = surf_train_120_hpf150(1:NtrainVec,st:en);
+  testvec = surf_all_hpf150(1:NtestVec,st:en);
+
+  % build up an array of tests to run ------------------------
+
+  sim_in.kmeans = "kmeans";
+  sim_in.trainvec = trainvec; sim_in.testvec = testvec; 
+  sim_in.Nvec = 64; 
+  sim_in.train_func_short = "mse"; 
+  sim_in.tests = ["slope"];
+
+  % Test1: mse training, 64, 128, 256, 512
+
+  sim_in_vec(1:4) = sim_in;
+  sim_in_vec(2).Nvec = 128; sim_in_vec(3).Nvec = 256; sim_in_vec(4).Nvec = 512;
+
+  test_list = sim_in_vec;
+
+  % run test list in parallel
+
+  [sd desc] =  pararrayfun(num_cores, @train_vq_and_run_tests, test_list);
+  %[sd desc] =  train_vq_and_run_tests(sim_in_vec(4));
+
+  % Plot results -----------------------------------------------
+
+  fg = 2;
+  plot_sd_results(fg++, "MSE Training 10..30", sd, desc, "mse", "slope");
+
+#{
+  % histogram of results from Nvec=64, all training methods, slope search
+
+  testnum = search_tests_Nvec(sd, desc, "mse", "slope", 64);
+  hist_sd = sd(:,testnum); hist_desc = desc(testnum);
+  testnum = search_tests_Nvec(sd, desc, "gain", "slope", 64);
+  hist_sd = [hist_sd sd(:,testnum)]; hist_desc = [hist_desc desc(testnum)];
+  testnum = search_tests_Nvec(sd, desc, "mag", "slope", 64);
+  hist_sd = [hist_sd sd(:,testnum)]; hist_desc = [hist_desc desc(testnum)];
+  testnum = search_tests_Nvec(sd, desc, "slope", "slope", 64);
+  hist_sd = [hist_sd sd(:,testnum)]; hist_desc = [hist_desc desc(testnum)];
+  compare_hist(fg, "Histogram of SDs Nvec=64", hist_sd, hist_desc);
+#}
+endfunction
+
+
+function vq = detailed_test(Nvec=64, st=1, en=10, quick_check=1, test_func = "vq_search_slope")
+  K = en-st+1;
   load surf_train_120_hpf150;
-  load surf_all;
-  NtrainVec = 1000;
-  NtestVec = 100;
-  trainvec = surf_train_120_hpf150(1:NtrainVec,1:K);
-  testvec = surf_all(1:NtestVec,1:K);
+  load surf_all_hpf150;
+  if quick_check
+    NtrainVec = 1000;
+    NtestVec = 100;
+  else
+    NtrainVec = length(surf_train_120_hpf150);
+    NtestVec = length(surf_all_hpf150);
+    NtestVec = 100;
+  end
 
-  Nvec = 64;  % we can plot all vectors on one screen of subplots
-  
-  [idx vq]  = kmeans2(trainvec, Nvec,
-                      "start", "sample", 
-                      "emptyaction", "singleton", 
-                      "search_func", train_func);
+  trainvec = surf_train_120_hpf150(1:NtrainVec,st:en);
+  testvec = surf_all_hpf150(1:NtestVec,st:en);
 
-  sdpf = run_test(vq, testvec, Nvec, test_func, gui_en=1);
+  [idx vq]  = kmeans(trainvec, Nvec,
+                     "start", "sample", 
+                     "emptyaction", "singleton");
+
+  run_test(vq, testvec, Nvec, test_func, gui_en=1);
 endfunction
 
 
-% Some contrived examples to test VQ training
+% ------------------------------------------------------------------
+%
+% Following test_* functions have some contrived examples to test VQ
+% training with each training method.  However the longtest_1_10
+% results suggest these training methds don't provide any improvement
+% over standard MSE kmeans.
 
 function test_training_mse
   K = 3; NtrainVec = 10; Nvec = 2;  
@@ -604,8 +754,11 @@ format; more off;
 rand('seed',1);    % kmeans using rand for initial population,
                    % we want same results on every run
 
-%short_detailed_test('vq_search_mag', 'vq_search_mag');
-[sd desc] = long_tests(quick_check=0);
+% choose one of these to run
+
+detailed_test(64, 1, 20, quick_check=1);
+%[sd desc] = long_tests_1_10(quick_check=1);
+%[sd desc] = long_tests_11_40(quick_check=0);
 %test_training_slope