10 fois la validation croisée en SVM un contre tous (en utilisant LibSVM)

je veux faire une validation croisée de 10 fois dans ma classification one-against-all 151930920" support vector machine dans MATLAB.

j'ai essayé de mélanger ces deux réponses:

  • classification à classes multiples dans libsvm
  • exemple de classification 10 fois SVM dans le MATLAB

mais comme je suis nouveau à MATLAB et sa syntaxe, Je n'ai pas réussi à le faire fonctionner jusqu'à maintenant.

d'un autre côté, je n'ai vu que les quelques lignes suivantes sur la validation croisée dans les fichiers LibSVM README et je n'ai pas pu trouver d'exemple correspondant:

option-v divise au hasard les données en n Parties et calcule croix précision de la validation / erreur quadratique moyenne à leur sujet.

voir libsvm FAQ pour la signification des sorties.

est-ce que quelqu'un pourrait me fournir un exemple de 10-fold cross-validation et de classification un-contre-tous?

10
demandé sur Community 2012-12-24 22:45:10

2 réponses

il y a principalement deux raisons pour lesquelles nous faisons validation croisée :

  • comme une méthode d'essai qui nous donne une estimation presque non biaisée de la puissance de généralisation de notre modèle (en évitant le sur-Ajustement)
  • comme un moyen de sélection de modèle (par exemple: trouver les meilleurs paramètres C et gamma sur les données de formation, voir ce post pour un exemple)

Pour le premier cas qui nous intéresse, le processus implique la formation k modèles pour chaque pli, puis la formation d'un modèle final sur l'ensemble de l'ensemble d'apprentissage. Nous indiquons la précision moyenne sur les plis de K.

maintenant, puisque nous utilisons une approche un vs-tous pour traiter le problème multi-classe, chaque modèle se compose de N support machines vectorielles (un pour chaque classe).


Les éléments suivants sont des fonctions wrapper mise en œuvre de celui-vs-all approche:

function mdl = libsvmtrain_ova(y, X, opts)
    if nargin < 3, opts = ''; end

    %# classes
    labels = unique(y);
    numLabels = numel(labels);

    %# train one-against-all models
    models = cell(numLabels,1);
    for k=1:numLabels
        models{k} = libsvmtrain(double(y==labels(k)), X, strcat(opts,' -b 1 -q'));
    end
    mdl = struct('models',{models}, 'labels',labels);
end

function [pred,acc,prob] = libsvmpredict_ova(y, X, mdl)
    %# classes
    labels = mdl.labels;
    numLabels = numel(labels);

    %# get probability estimates of test instances using each 1-vs-all model
    prob = zeros(size(X,1), numLabels);
    for k=1:numLabels
        [~,~,p] = libsvmpredict(double(y==labels(k)), X, mdl.models{k}, '-b 1 -q');
        prob(:,k) = p(:, mdl.models{k}.Label==1);
    end

    %# predict the class with the highest probability
    [~,pred] = max(prob, [], 2);
    %# compute classification accuracy
    acc = mean(pred == y);
end

et voici les fonctions pour supporter la validation croisée:

function acc = libsvmcrossval_ova(y, X, opts, nfold, indices)
    if nargin < 3, opts = ''; end
    if nargin < 4, nfold = 10; end
    if nargin < 5, indices = crossvalidation(y, nfold); end

    %# N-fold cross-validation testing
    acc = zeros(nfold,1);
    for i=1:nfold
        testIdx = (indices == i); trainIdx = ~testIdx;
        mdl = libsvmtrain_ova(y(trainIdx), X(trainIdx,:), opts);
        [~,acc(i)] = libsvmpredict_ova(y(testIdx), X(testIdx,:), mdl);
    end
    acc = mean(acc);    %# average accuracy
end

function indices = crossvalidation(y, nfold)
    %# stratified n-fold cros-validation
    %#indices = crossvalind('Kfold', y, nfold);  %# Bioinformatics toolbox
    cv = cvpartition(y, 'kfold',nfold);          %# Statistics toolbox
    indices = zeros(size(y));
    for i=1:nfold
        indices(cv.test(i)) = i;
    end
end

enfin, voici une démo simple pour illustrer l'usage:

%# laod dataset
S = load('fisheriris');
data = zscore(S.meas);
labels = grp2idx(S.species);

%# cross-validate using one-vs-all approach
opts = '-s 0 -t 2 -c 1 -g 0.25';    %# libsvm training options
nfold = 10;
acc = libsvmcrossval_ova(labels, data, opts, nfold);
fprintf('Cross Validation Accuracy = %.4f%%\n', 100*mean(acc));

%# compute final model over the entire dataset
mdl = libsvmtrain_ova(labels, data, opts);

comparez cela avec l'approche one-vs-one qui est utilisée par défaut par libsvm:

acc = libsvmtrain(labels, data, sprintf('%s -v %d -q',opts,nfold));
model = libsvmtrain(labels, data, strcat(opts,' -q'));
15
répondu Amro 2017-05-23 12:17:00

il se peut que vous soyez confus que l'une des deux questions ne concerne pas LIBSVM. Vous devriez essayer d'ajuster cette réponse et ignorer l'autre.

Vous devez sélectionner les plis, et faire le reste exactement comme lié question. Supposons que les données ont été chargées dans data et les étiquettes dans labels :

n = size(data,1);
ns = floor(n/10);
for fold=1:10,
    if fold==1,
        testindices= ((fold-1)*ns+1):fold*ns;
        trainindices = fold*ns+1:n;
    else
        if fold==10,
            testindices= ((fold-1)*ns+1):n;
            trainindices = 1:(fold-1)*ns;
        else
            testindices= ((fold-1)*ns+1):fold*ns;
            trainindices = [1:(fold-1)*ns,fold*ns+1:n];
         end
    end
    % use testindices only for testing and train indices only for testing
    trainLabel = label(trainindices);
    trainData = data(trainindices,:);
    testLabel = label(testindices);
    testData = data(testindices,:)
    %# train one-against-all models
    model = cell(numLabels,1);
    for k=1:numLabels
        model{k} = svmtrain(double(trainLabel==k), trainData, '-c 1 -g 0.2 -b 1');
    end

    %# get probability estimates of test instances using each model
    prob = zeros(size(testData,1),numLabels);
    for k=1:numLabels
        [~,~,p] = svmpredict(double(testLabel==k), testData, model{k}, '-b 1');
        prob(:,k) = p(:,model{k}.Label==1);    %# probability of class==k
    end

    %# predict the class with the highest probability
    [~,pred] = max(prob,[],2);
    acc = sum(pred == testLabel) ./ numel(testLabel)    %# accuracy
    C = confusionmat(testLabel, pred)                   %# confusion matrix
end
2
répondu carlosdc 2017-05-23 12:34:07