From d225601948727309e0b5bf289a83c8a1c42c5917 Mon Sep 17 00:00:00 2001 From: drowe67 Date: Thu, 6 Jul 2017 09:45:55 +0000 Subject: [PATCH] added a GUI mode for stepping thru frames to help evaluate VQ git-svn-id: https://svn.code.sf.net/p/freetel/code@3279 01035d8c-6547-0410-b346-abe4f91aad63 --- codec2-dev/octave/kmeans_tests.m | 247 +++++++++++++++++++++++++------ 1 file changed, 200 insertions(+), 47 deletions(-) diff --git a/codec2-dev/octave/kmeans_tests.m b/codec2-dev/octave/kmeans_tests.m index 24159ca9..9b4d8837 100644 --- a/codec2-dev/octave/kmeans_tests.m +++ b/codec2-dev/octave/kmeans_tests.m @@ -17,7 +17,7 @@ %---------------------------------------------------------------------- % standard mean squared error search -function [idx contrib errors test_ g mg] = vq_search_mse(vq, data) +function [idx contrib errors test_ g mg sl] = vq_search_mse(vq, data) [nVec nCols] = size(vq); nRows = rows(data); @@ -38,14 +38,14 @@ function [idx contrib errors test_ g mg] = vq_search_mse(vq, data) test_(f,:) = vq(min_ind,:); end - g = mg = 1; % dummys for this function + g = 0; mg = 1; sl = 0; % dummys for this function endfunction %---------------------------------------------------------------------- % abs() search with a linear gain term -function [idx contrib errors test_ g mg] = vq_search_gain(vq, data) +function [idx contrib errors test_ g mg sl] = vq_search_gain(vq, data) [nVec nCols] = size(vq); nRows = rows(data); @@ -79,14 +79,14 @@ function [idx contrib errors test_ g mg] = vq_search_gain(vq, data) errors(f) = mn; contrib(f,:) = test_(f,:) = vq(min_ind,:) + g(f,min_ind); end - mg = 1; + mg = 1; sl = 0; endfunction %---------------------------------------------------------------------- % abs() search with a linear plus ampl scaling term -function [idx contrib errors test_ g mg] = vq_search_mag(vq, data) +function [idx contrib errors test_ g mg sl] = vq_search_mag(vq, data) [nVec nCols] = size(vq); nRows = rows(data); @@ -128,7 +128,7 @@ function [idx contrib errors test_ g mg] = vq_search_mag(vq, data) contrib(f,:) = test_(f,:) = mg(f,min_ind) * vq(min_ind,:) + g(f,min_ind); end - + sl = 0; endfunction @@ -202,8 +202,8 @@ function sd_per_frame = run_test(vq, test, nVec, search_func, gui_en = 0) [nRows nCols] = size(test); - [idx contrib errors test_ g mg] = feval(search_func, vq, test); - + [idx contrib errors test_ g mg sl] = feval(search_func, vq, test); + % sd over time sd_per_frame = zeros(nRows,1); @@ -211,44 +211,111 @@ function sd_per_frame = run_test(vq, test, nVec, search_func, gui_en = 0) sd_per_frame(i) = mean(abs(test(i,:) - test_(i,:))); end + printf("average sd: %3.2f\n", mean(sd_per_frame)); %printf("%18s nVec: %d SD: %3.2f dB\n", search_func, nVec, mean(sd_per_frame)); - % plots sd and errors over time + % Optional GUI with U to anlayse VQ perf --------------------------------------------------- if gui_en - figure(1); clf; subplot(211); plot(sd_per_frame); title('SD'); subplot(212); plot(errors); title('mean error'); + gui_mode = "sort"; % "sort" or "fbf" + sort_order = "worst"; % "worst" or "best" + f = 1; + + % plots sd over time and histogram + + figure(1); clf; subplot(211); plot(sd_per_frame); title('SD'); + subplot(212); [yy xx] = hist(sd_per_frame); plot(xx,yy,'+-'); + end + + while gui_en % display m frames, printing some stats, plotting vector to give visual idea of match figure(2); clf; - [errors_dec frame_dec] = sort(errors, "descend"); - m = 4; + + % try this in 'ascend' or 'descend' mode to best or worst frames + + if strcmp(gui_mode, "sort") + if strcmp(sort_order, "best") + [errors_dec frame_dec] = sort(errors, "ascend"); + end + if strcmp(sort_order, "worst") + [errors_dec frame_dec] = sort(errors, "descend"); + end + m = 4; + else + m = 1; + errors_dec = errors(f); frame_dec(1) = f; + end + for i=1:m + + % build up VQ legend + af = frame_dec(i); aind = idx(af); - l = sprintf("idx: %d", aind); - ag = 0; amg = 1; - if strcmp(search_func, "vq_search_gain") || strcmp(search_func, "vq_search_mag") + + ag = 0; amg = 1; asl = 0; + if strcmp(search_func, "vq_search_gain") ag = g(af,aind); - l = sprintf("%s g: %3.2f", l, ag); + l3 = sprintf("g: %3.2f", ag); end if strcmp(search_func, "vq_search_mag") + ag = g(af,aind); + amg = mg(af,aind); + l3 = sprintf("g: %3.2f mg: %3.2f", ag, amg); + end + if strcmp(search_func, "vq_search_slope") + ag = g(af,aind); amg = mg(af,aind); - l = sprintf("%s mg: %3.2f", l, amg); + asl = sl(af,aind); + l3 = sprintf("g: %3.2f mg: %3.2f sl: %3.2f", ag, amg, asl); end - %printf("%d f: %d %s\n", i, af, l); - subplot(sqrt(m),sqrt(m),i); - l1 = sprintf("b-;fr %d;", af); + % plot target + + subplot(sqrt(m), sqrt(m),i); + l1 = sprintf("b-;fr %d sd: %3.2f;", af, errors_dec(m)); plot(test(af,:), l1); + % plot vq vector and modified version + hold on; l2 = sprintf("g-+;ind %d;", aind); plot(vq(aind, :), l2); - l3 = sprintf("g-o;%s;",l); - plot(amg*vq(aind, :) + ag, l3); + l3 = sprintf("g-o;%s;",l3); + plot(amg*vq(aind, :) + ag + asl*(1:nCols), l3); hold off; axis([1 nCols -10 40]); end + + % interactive menu ------------------------------------------ + + if strcmp(gui_mode, "sort") + printf("\rmenu: m-mode[%s] o-order[%s] q-quit", gui_mode, sort_order); + end + if strcmp(gui_mode, "fbf") + printf("\rmenu: m-mode[%s] frame: %d n-next b-back q-quit", gui_mode, f); + end + fflush(stdout); + k = kbhit(); + + if k == 'm' + if strcmp(gui_mode, "sort") + gui_mode = "fbf"; f = frame_dec(1); + else + gui_mode = "sort"; + end; + end + if k == 'o' + if strcmp(sort_order, "worst") + sort_order = "best"; + else + sort_order = "worst"; + end; + end + if k == 'n', f = f + 1; endif; + if k == 'b', f = f - 1; endif; + if k == 'q', gui_en = 0; printf("\n"); endif; end endfunction @@ -277,10 +344,13 @@ function [sd des] = train_vq_and_run_tests(sim_in) Nvec = sim_in.Nvec; train_func = get_search_func(sim_in.train_func_short); - [idx vq] = kmeans(sim_in.trainvec, Nvec, - "start", "sample", - "emptyaction", "singleton", - "search_func", train_func); + printf(" Nvec %d kmeans start\n", Nvec); + [idx vq] = feval(sim_in.kmeans, + sim_in.trainvec, Nvec, + "start", "sample", + "emptyaction", "singleton", + "search_func", train_func); + printf(" Nvec %d kmeans end\n", Nvec); sd = []; des = []; tests = sim_in.tests; for t=1:length(cellstr(tests)) @@ -323,7 +393,7 @@ function plot_sd_results(fg, title_str, sd, desc, train_list, test_list) for j=1:length(cellstr(test_list)) test_func_short = char(cellstr(test_list)(j)); [y x] = search_tests(sd, desc, train_func_short, test_func_short); - leg = sprintf("o-%d;train: %5s test: %5s;", nlines, train_func_short, test_func_short) + leg = sprintf("o-%d;train: %5s test: %5s;", nlines, train_func_short, test_func_short); if nlines, hold on; end; x += inc; inc += 0.1; % separate x coords a bit to make errors bars legible errorbar(x, y(:,1), y(:,2), leg); @@ -372,8 +442,15 @@ end %---------------------------------------------------------------------- % Run a bunch of long tests in parallel and plot results +% + +% This test tries a bunch of training and test search combinations on +% the first 1kHz to see if there is any advantage in using the higher +% order search techniques in the VQ training. So far it appears +% not.... However the higher order search routines do give great +% results compared to a basic MSE (or oder) search. -function [sd desc] = long_tests(quick_check=0) +function [sd desc] = long_tests_1_10(quick_check=0) num_cores = 4; K = 10; load surf_train_120_hpf150; load surf_all_hpf150; @@ -391,7 +468,8 @@ function [sd desc] = long_tests(quick_check=0) % build up a big array of tests to run ------------------------ - sim_in.trainvec = trainvec; sim_in.testvec = testvec; + sim_in.trainvec = trainvec; sim_in.testvec = testvec; + sim_in.kmeans = "kmeans2"; % slow but supports different search functions sim_in.Nvec = 64; sim_in.train_func_short = "mse"; sim_in.tests = ["mse"; "gain"; "mag"; "slope"]; @@ -421,7 +499,8 @@ function [sd desc] = long_tests(quick_check=0) % run test list in parallel [sd desc] = pararrayfun(num_cores, @train_vq_and_run_tests, test_list); - + %train_vq_and_run_tests(test_list(1)); + % Plot results ----------------------------------------------- fg = 1; @@ -446,27 +525,98 @@ function [sd desc] = long_tests(quick_check=0) endfunction -function short_detailed_test(train_func, test_func) - K = 10; +%---------------------------------------------------------------------- +% This test tries some test combinations on the 1000 to 3000 Hz range. +% +% Based on our results with long_test_1_10() we just use MSE training + +function [sd desc] = long_tests_11_40(quick_check=0) + num_cores = 4; + K = 20; st = 11; en = 40; + load surf_train_120_hpf150; load surf_all_hpf150; + + if quick_check + NtrainVec = 1000; + NtestVec = 100; + else + NtrainVec = length(surf_train_120_hpf150); + NtestVec = length(surf_all_hpf150); + end + + trainvec = surf_train_120_hpf150(1:NtrainVec,st:en); + testvec = surf_all_hpf150(1:NtestVec,st:en); + + % build up an array of tests to run ------------------------ + + sim_in.kmeans = "kmeans"; + sim_in.trainvec = trainvec; sim_in.testvec = testvec; + sim_in.Nvec = 64; + sim_in.train_func_short = "mse"; + sim_in.tests = ["slope"]; + + % Test1: mse training, 64, 128, 256, 512 + + sim_in_vec(1:4) = sim_in; + sim_in_vec(2).Nvec = 128; sim_in_vec(3).Nvec = 256; sim_in_vec(4).Nvec = 512; + + test_list = sim_in_vec; + + % run test list in parallel + + [sd desc] = pararrayfun(num_cores, @train_vq_and_run_tests, test_list); + %[sd desc] = train_vq_and_run_tests(sim_in_vec(4)); + + % Plot results ----------------------------------------------- + + fg = 2; + plot_sd_results(fg++, "MSE Training 10..30", sd, desc, "mse", "slope"); + +#{ + % histogram of results from Nvec=64, all training methods, slope search + + testnum = search_tests_Nvec(sd, desc, "mse", "slope", 64); + hist_sd = sd(:,testnum); hist_desc = desc(testnum); + testnum = search_tests_Nvec(sd, desc, "gain", "slope", 64); + hist_sd = [hist_sd sd(:,testnum)]; hist_desc = [hist_desc desc(testnum)]; + testnum = search_tests_Nvec(sd, desc, "mag", "slope", 64); + hist_sd = [hist_sd sd(:,testnum)]; hist_desc = [hist_desc desc(testnum)]; + testnum = search_tests_Nvec(sd, desc, "slope", "slope", 64); + hist_sd = [hist_sd sd(:,testnum)]; hist_desc = [hist_desc desc(testnum)]; + compare_hist(fg, "Histogram of SDs Nvec=64", hist_sd, hist_desc); +#} +endfunction + + +function vq = detailed_test(Nvec=64, st=1, en=10, quick_check=1, test_func = "vq_search_slope") + K = en-st+1; load surf_train_120_hpf150; - load surf_all; - NtrainVec = 1000; - NtestVec = 100; - trainvec = surf_train_120_hpf150(1:NtrainVec,1:K); - testvec = surf_all(1:NtestVec,1:K); + load surf_all_hpf150; + if quick_check + NtrainVec = 1000; + NtestVec = 100; + else + NtrainVec = length(surf_train_120_hpf150); + NtestVec = length(surf_all_hpf150); + NtestVec = 100; + end - Nvec = 64; % we can plot all vectors on one screen of subplots - - [idx vq] = kmeans2(trainvec, Nvec, - "start", "sample", - "emptyaction", "singleton", - "search_func", train_func); + trainvec = surf_train_120_hpf150(1:NtrainVec,st:en); + testvec = surf_all_hpf150(1:NtestVec,st:en); - sdpf = run_test(vq, testvec, Nvec, test_func, gui_en=1); + [idx vq] = kmeans(trainvec, Nvec, + "start", "sample", + "emptyaction", "singleton"); + + run_test(vq, testvec, Nvec, test_func, gui_en=1); endfunction -% Some contrived examples to test VQ training +% ------------------------------------------------------------------ +% +% Following test_* functions have some contrived examples to test VQ +% training with each training method. However the longtest_1_10 +% results suggest these training methds don't provide any improvement +% over standard MSE kmeans. function test_training_mse K = 3; NtrainVec = 10; Nvec = 2; @@ -604,8 +754,11 @@ format; more off; rand('seed',1); % kmeans using rand for initial population, % we want same results on every run -%short_detailed_test('vq_search_mag', 'vq_search_mag'); -[sd desc] = long_tests(quick_check=0); +% choose one of these to run + +detailed_test(64, 1, 20, quick_check=1); +%[sd desc] = long_tests_1_10(quick_check=1); +%[sd desc] = long_tests_11_40(quick_check=0); %test_training_slope -- 2.25.1