refactoring kmeans again lol
authordrowe67 <drowe67@01035d8c-6547-0410-b346-abe4f91aad63>
Tue, 4 Jul 2017 01:31:16 +0000 (01:31 +0000)
committerdrowe67 <drowe67@01035d8c-6547-0410-b346-abe4f91aad63>
Tue, 4 Jul 2017 01:31:16 +0000 (01:31 +0000)
git-svn-id: https://svn.code.sf.net/p/freetel/code@3277 01035d8c-6547-0410-b346-abe4f91aad63

codec2-dev/octave/kmeans_tests.m

index 3efdf8032bfdc57776b39cbbf20ecc809b31f85a..c3e604f15566f465b25bf3d579dc5e1f3701128e 100644 (file)
@@ -14,6 +14,7 @@
 %
 %----------------------------------------------------------------------
 
+%----------------------------------------------------------------------
 % standard mean squared error search
 
 function [idx contrib errors test_ g mg] = vq_search_mse(vq, data)
@@ -41,6 +42,7 @@ function [idx contrib errors test_ g mg] = vq_search_mse(vq, data)
 endfunction
 
 
+%----------------------------------------------------------------------
 % abs() search with a linear gain term
 
 function [idx contrib errors test_ g mg] = vq_search_gain(vq, data)
@@ -81,6 +83,7 @@ function [idx contrib errors test_ g mg] = vq_search_gain(vq, data)
 endfunction
 
 
+%----------------------------------------------------------------------
 % abs() search with a linear plus ampl scaling term
 
 function [idx contrib errors test_ g mg] = vq_search_mag(vq, data)
@@ -132,10 +135,16 @@ function [idx contrib errors test_ g mg] = vq_search_mag(vq, data)
 endfunction
 
 
+%----------------------------------------------------------------------
+%
+% Functions to support simulation of different VQ training and testing
+%
+%----------------------------------------------------------------------
+
 % evaluate database test using vq, with selectable search function.  Can be operated in
 % GUI mode to analyse in fine detail or batch mode to evaluate lots of data.
 
-function sd_per_frame = run_test(vq, test, nVec, search_func, gui_en = 1)
+function sd_per_frame = run_test(vq, test, nVec, search_func, gui_en = 0)
 
   % Test VQ using test data -----------------------
 
@@ -155,7 +164,7 @@ function sd_per_frame = run_test(vq, test, nVec, search_func, gui_en = 1)
     sd_per_frame(i) = mean(abs(test(i,:) - test_(i,:)));
   end
 
-  printf("%18s mean SD: %3.2f dB\n", search_func, mean(sd_per_frame));
+  %printf("%18s nVec: %d SD: %3.2f dB\n", search_func, nVec, mean(sd_per_frame));
   
   % plots sd and errors over time
 
@@ -198,40 +207,25 @@ function sd_per_frame = run_test(vq, test, nVec, search_func, gui_en = 1)
 endfunction
 
 
-function compare_hist(atitle, sdpf_mse, sdpf_gain, sdpf_mag)
-  [mse_yy, mse_xx] = hist(sdpf_mse);
-  [gain_yy, gain_xx] = hist(sdpf_gain);
-  [mag_yy, mag_xx] = hist(sdpf_mag);
-
-  plot(mse_xx, mse_yy, 'b+-;mse;');
-  hold on;
-  plot(gain_xx, gain_yy, 'g+-;gain;');
-  plot(mag_xx, mag_yy, 'r+-;mag;');
-  hold off;
+%----------------------------------------------------------------------
+%
+% Plot histograms of SDs for comparison % % Each col of sd_per_frame
+% has the results of one test, number of cols % is number of tests.  leg
+% is a col vector with one legend string for each test.
+
+function compare_hist(atitle, sdpf, leg)
+  [nRows nCols] = size(sd);
+  for c=1:nCols
+    [yy, xx] = hist(sd(:, c));
+    if c == 2, hold on; end;
+    l = char(cellstr(leg)(c));
+    plot(xx, yy, leg(c));
+  end
+  if nCols > 1, hold off; end;
   title(atitle)
 end
 
 
-function sd = three_tests(sim_in)
-  trainvec = sim_in.trainvec;
-  testvec = sim_in.testvec;
-  Nvec = sim_in.Nvec;
-  train_func = sim_in.train_func;
-
-  printf("  Nvec: %d\n", Nvec);
-
-  [idx vq] = kmeans(trainvec, Nvec, 
-                    "start", "sample", 
-                    "emptyaction", "singleton",
-                    "search_func", train_func);
-
-  sd_mse  = run_test(vq, testvec, Nvec, 'vq_search_mse', gui_en=0);
-  sd_gain = run_test(vq, testvec, Nvec, 'vq_search_gain', gui_en=0);
-  sd_mag = run_test(vq, testvec, Nvec, 'vq_search_mag', gui_en=0);
-
-  sd = [sd_mse sd_gain sd_mag];
-endfunction
-
 function plot_sd_results(title_str, fg, offset, sd)
   figure(fg); clf;  
   samples = offset+[1 4 7];
@@ -251,13 +245,13 @@ endfunction
 
 function plot_sd_results2(title_str, fg, sd)
   figure(fg); clf;  
-  samples = 0+[3 6 9];
+  samples = 9+[3 6 9];
   bits = log2([64 128 256]);
   errorbar(bits-0.1, mean(sd(:,samples)), std(sd(:, samples),[]),'b+-; MSE train;');
   hold on;
-  samples = 9+[3 6 9];
-  errorbar(bits+0.0, mean(sd(:,samples)), std(sd(:,samples),[]),'g+-;Gain train;');
   samples = 18+[3 6 9];
+  errorbar(bits+0.0, mean(sd(:,samples)), std(sd(:,samples),[]),'g+-;Gain train;');
+  samples = 27+[3 6 9];
   errorbar(bits+0.1, mean(sd(:,samples)), std(sd(:,samples),[]),'r+-;Mag train;');
   hold off;
   xlabel('VQ size (bits)')
@@ -266,6 +260,45 @@ function plot_sd_results2(title_str, fg, sd)
 endfunction
 
 
+function search_func = get_search_func(short_name)
+  if strcmp(short_name, "mse")
+    search_func = 'vq_search_mse';
+  end
+  if strcmp(short_name, "gain")
+    search_func = 'vq_search_gain';
+  end
+  if strcmp(short_name, "mag")
+    search_func = 'vq_search_mag';
+  end
+end
+
+
+%----------------------------------------------------------------------
+% Train up a  VQ and run one or mode tests
+
+function [sd des] = train_vq_and_run_tests(sim_in)
+  Nvec = sim_in.Nvec;
+
+  train_func = get_search_func(sim_in.train_func_short);
+  [idx vq] = kmeans(sim_in.trainvec, Nvec, 
+                    "start", "sample", 
+                    "emptyaction", "singleton",
+                    "search_func", train_func);
+
+  sd = []; des = []; tests = sim_in.tests;
+  for t=1:length(cellstr(tests))
+    test_func_short = char(cellstr(tests)(t));
+    test_func = get_search_func(test_func_short);
+    asd = run_test(vq, sim_in.testvec, Nvec, test_func); 
+    sd = [sd asd];
+    ades = sprintf("Nvec: %3d train: %-4s test: %-4s", 
+           Nvec, sim_in.train_func_short, test_func_short); 
+    des = [des; ades];
+    printf("    %s SD: %3.2f\n", ades, mean(asd));
+  end
+
+endfunction
+
 function sd = long_tests(quick_check=0)
   num_cores = 4;
   K = 10;
@@ -285,9 +318,18 @@ function sd = long_tests(quick_check=0)
   trainvec_hpf150 = surf_train_120_hpf150(1:NtrainVec,1:K);
   testvec_hpf150 = surf_all_hpf150(1:NtestVec,1:K);
 
-  sim_in.trainvec = trainvec; sim_in.testvec = testvec; sim_in.Nvec = 64; sim_in.train_func = 'vq_search_mse';
+  sim_in.trainvec = trainvec; sim_in.testvec = testvec; 
+  sim_in.Nvec = 64; 
+  sim_in.train_func_short = "mse"; 
+  sim_in.tests = ["mse"; "gain"; "mag"];
   sim_in_vec(1:3) = sim_in;
   sim_in_vec(2).Nvec = 128; sim_in_vec(3).Nvec = 256;
+
+  [sd desc] = train_vq_and_run_tests(sim_in(1));
+  
+  desc
+
+#{
   sd =  pararrayfun(num_cores, @three_tests, sim_in_vec);
   plot_sd_results("MSE Training", 1, 0, sd)
 
@@ -297,15 +339,16 @@ function sd = long_tests(quick_check=0)
 
   for i=1:3, sim_in_vec(i).train_func = 'vq_search_gain'; end;
   sd = [sd pararrayfun(num_cores, @three_tests, sim_in_vec)];
-  plot_sd_results("Gain training 150Hz HPF", 3, 9, sd)
+  plot_sd_results("Gain training 150Hz HPF", 3, 18, sd)
 
   for i=1:3, sim_in_vec(i).train_func = 'vq_search_mse'; end;
   sd = [sd pararrayfun(num_cores, @three_tests, sim_in_vec)];
-  plot_sd_results("Mag training 150Hz HPF", 4, 9, sd)
+  plot_sd_results("Mag training 150Hz HPF", 4, 27, sd)
 
   plot_sd_results2("Mag search 150Hz HPF", 5, sd)
   
-  figure(6); clf; compare_hist("Mag train Nvec=64", sd(:,3), sd(:,12), sd(:,21));
+  figure(6); clf; compare_hist("Mag search Nvec=64", sd(:,3), sd(:,12), sd(:,21));
+#}
 endfunction
 
 
@@ -431,7 +474,7 @@ rand('seed',1);    % kmeans using rand for initial population,
                    % we want same results on every run
 
 %short_detailed_test('vq_search_mag', 'vq_search_mag');
-sd = long_tests(quick_check=0);
+sd = long_tests(quick_check=1);
 %test_training_mag