From: drowe67 Date: Tue, 4 Jul 2017 01:31:16 +0000 (+0000) Subject: refactoring kmeans again lol X-Git-Url: http://git.whiteaudio.com/gitweb/?a=commitdiff_plain;h=b2cfc211b41baadbd9f74fdf5a51ce1d995bbe84;p=freetel-svn-tracking.git refactoring kmeans again lol git-svn-id: https://svn.code.sf.net/p/freetel/code@3277 01035d8c-6547-0410-b346-abe4f91aad63 --- diff --git a/codec2-dev/octave/kmeans_tests.m b/codec2-dev/octave/kmeans_tests.m index 3efdf803..c3e604f1 100644 --- a/codec2-dev/octave/kmeans_tests.m +++ b/codec2-dev/octave/kmeans_tests.m @@ -14,6 +14,7 @@ % %---------------------------------------------------------------------- +%---------------------------------------------------------------------- % standard mean squared error search function [idx contrib errors test_ g mg] = vq_search_mse(vq, data) @@ -41,6 +42,7 @@ function [idx contrib errors test_ g mg] = vq_search_mse(vq, data) endfunction +%---------------------------------------------------------------------- % abs() search with a linear gain term function [idx contrib errors test_ g mg] = vq_search_gain(vq, data) @@ -81,6 +83,7 @@ function [idx contrib errors test_ g mg] = vq_search_gain(vq, data) endfunction +%---------------------------------------------------------------------- % abs() search with a linear plus ampl scaling term function [idx contrib errors test_ g mg] = vq_search_mag(vq, data) @@ -132,10 +135,16 @@ function [idx contrib errors test_ g mg] = vq_search_mag(vq, data) endfunction +%---------------------------------------------------------------------- +% +% Functions to support simulation of different VQ training and testing +% +%---------------------------------------------------------------------- + % evaluate database test using vq, with selectable search function. Can be operated in % GUI mode to analyse in fine detail or batch mode to evaluate lots of data. -function sd_per_frame = run_test(vq, test, nVec, search_func, gui_en = 1) +function sd_per_frame = run_test(vq, test, nVec, search_func, gui_en = 0) % Test VQ using test data ----------------------- @@ -155,7 +164,7 @@ function sd_per_frame = run_test(vq, test, nVec, search_func, gui_en = 1) sd_per_frame(i) = mean(abs(test(i,:) - test_(i,:))); end - printf("%18s mean SD: %3.2f dB\n", search_func, mean(sd_per_frame)); + %printf("%18s nVec: %d SD: %3.2f dB\n", search_func, nVec, mean(sd_per_frame)); % plots sd and errors over time @@ -198,40 +207,25 @@ function sd_per_frame = run_test(vq, test, nVec, search_func, gui_en = 1) endfunction -function compare_hist(atitle, sdpf_mse, sdpf_gain, sdpf_mag) - [mse_yy, mse_xx] = hist(sdpf_mse); - [gain_yy, gain_xx] = hist(sdpf_gain); - [mag_yy, mag_xx] = hist(sdpf_mag); - - plot(mse_xx, mse_yy, 'b+-;mse;'); - hold on; - plot(gain_xx, gain_yy, 'g+-;gain;'); - plot(mag_xx, mag_yy, 'r+-;mag;'); - hold off; +%---------------------------------------------------------------------- +% +% Plot histograms of SDs for comparison % % Each col of sd_per_frame +% has the results of one test, number of cols % is number of tests. leg +% is a col vector with one legend string for each test. + +function compare_hist(atitle, sdpf, leg) + [nRows nCols] = size(sd); + for c=1:nCols + [yy, xx] = hist(sd(:, c)); + if c == 2, hold on; end; + l = char(cellstr(leg)(c)); + plot(xx, yy, leg(c)); + end + if nCols > 1, hold off; end; title(atitle) end -function sd = three_tests(sim_in) - trainvec = sim_in.trainvec; - testvec = sim_in.testvec; - Nvec = sim_in.Nvec; - train_func = sim_in.train_func; - - printf(" Nvec: %d\n", Nvec); - - [idx vq] = kmeans(trainvec, Nvec, - "start", "sample", - "emptyaction", "singleton", - "search_func", train_func); - - sd_mse = run_test(vq, testvec, Nvec, 'vq_search_mse', gui_en=0); - sd_gain = run_test(vq, testvec, Nvec, 'vq_search_gain', gui_en=0); - sd_mag = run_test(vq, testvec, Nvec, 'vq_search_mag', gui_en=0); - - sd = [sd_mse sd_gain sd_mag]; -endfunction - function plot_sd_results(title_str, fg, offset, sd) figure(fg); clf; samples = offset+[1 4 7]; @@ -251,13 +245,13 @@ endfunction function plot_sd_results2(title_str, fg, sd) figure(fg); clf; - samples = 0+[3 6 9]; + samples = 9+[3 6 9]; bits = log2([64 128 256]); errorbar(bits-0.1, mean(sd(:,samples)), std(sd(:, samples),[]),'b+-; MSE train;'); hold on; - samples = 9+[3 6 9]; - errorbar(bits+0.0, mean(sd(:,samples)), std(sd(:,samples),[]),'g+-;Gain train;'); samples = 18+[3 6 9]; + errorbar(bits+0.0, mean(sd(:,samples)), std(sd(:,samples),[]),'g+-;Gain train;'); + samples = 27+[3 6 9]; errorbar(bits+0.1, mean(sd(:,samples)), std(sd(:,samples),[]),'r+-;Mag train;'); hold off; xlabel('VQ size (bits)') @@ -266,6 +260,45 @@ function plot_sd_results2(title_str, fg, sd) endfunction +function search_func = get_search_func(short_name) + if strcmp(short_name, "mse") + search_func = 'vq_search_mse'; + end + if strcmp(short_name, "gain") + search_func = 'vq_search_gain'; + end + if strcmp(short_name, "mag") + search_func = 'vq_search_mag'; + end +end + + +%---------------------------------------------------------------------- +% Train up a VQ and run one or mode tests + +function [sd des] = train_vq_and_run_tests(sim_in) + Nvec = sim_in.Nvec; + + train_func = get_search_func(sim_in.train_func_short); + [idx vq] = kmeans(sim_in.trainvec, Nvec, + "start", "sample", + "emptyaction", "singleton", + "search_func", train_func); + + sd = []; des = []; tests = sim_in.tests; + for t=1:length(cellstr(tests)) + test_func_short = char(cellstr(tests)(t)); + test_func = get_search_func(test_func_short); + asd = run_test(vq, sim_in.testvec, Nvec, test_func); + sd = [sd asd]; + ades = sprintf("Nvec: %3d train: %-4s test: %-4s", + Nvec, sim_in.train_func_short, test_func_short); + des = [des; ades]; + printf(" %s SD: %3.2f\n", ades, mean(asd)); + end + +endfunction + function sd = long_tests(quick_check=0) num_cores = 4; K = 10; @@ -285,9 +318,18 @@ function sd = long_tests(quick_check=0) trainvec_hpf150 = surf_train_120_hpf150(1:NtrainVec,1:K); testvec_hpf150 = surf_all_hpf150(1:NtestVec,1:K); - sim_in.trainvec = trainvec; sim_in.testvec = testvec; sim_in.Nvec = 64; sim_in.train_func = 'vq_search_mse'; + sim_in.trainvec = trainvec; sim_in.testvec = testvec; + sim_in.Nvec = 64; + sim_in.train_func_short = "mse"; + sim_in.tests = ["mse"; "gain"; "mag"]; sim_in_vec(1:3) = sim_in; sim_in_vec(2).Nvec = 128; sim_in_vec(3).Nvec = 256; + + [sd desc] = train_vq_and_run_tests(sim_in(1)); + + desc + +#{ sd = pararrayfun(num_cores, @three_tests, sim_in_vec); plot_sd_results("MSE Training", 1, 0, sd) @@ -297,15 +339,16 @@ function sd = long_tests(quick_check=0) for i=1:3, sim_in_vec(i).train_func = 'vq_search_gain'; end; sd = [sd pararrayfun(num_cores, @three_tests, sim_in_vec)]; - plot_sd_results("Gain training 150Hz HPF", 3, 9, sd) + plot_sd_results("Gain training 150Hz HPF", 3, 18, sd) for i=1:3, sim_in_vec(i).train_func = 'vq_search_mse'; end; sd = [sd pararrayfun(num_cores, @three_tests, sim_in_vec)]; - plot_sd_results("Mag training 150Hz HPF", 4, 9, sd) + plot_sd_results("Mag training 150Hz HPF", 4, 27, sd) plot_sd_results2("Mag search 150Hz HPF", 5, sd) - figure(6); clf; compare_hist("Mag train Nvec=64", sd(:,3), sd(:,12), sd(:,21)); + figure(6); clf; compare_hist("Mag search Nvec=64", sd(:,3), sd(:,12), sd(:,21)); +#} endfunction @@ -431,7 +474,7 @@ rand('seed',1); % kmeans using rand for initial population, % we want same results on every run %short_detailed_test('vq_search_mag', 'vq_search_mag'); -sd = long_tests(quick_check=0); +sd = long_tests(quick_check=1); %test_training_mag