Package hivemall.io

Examples of hivemall.io.PredictionResult


    @Override
    protected void train(Collection<?> features, float target) {
        preTrain(target);

        PredictionResult margin = calcScoreAndNorm(features);
        float predicted = margin.getScore();
        float loss = loss(target, predicted);

        if(loss > 0.f) {
            int sign = (target - predicted) > 0.f ? 1 : -1; // sign(y - (W^t)x)
            float eta = eta(loss, margin); // min(C, loss / |x|^2)
View Full Code Here


                maxScore = score;
                maxScoredLabel = label;
            }
        }

        return new PredictionResult(maxScoredLabel, maxScore);
    }
View Full Code Here

        }

        for(Map.Entry<Object, PredictionModel> label2map : label2model.entrySet()) {// for each class
            Object label = label2map.getKey();
            PredictionModel model = label2map.getValue();
            PredictionResult predicted = calcScoreAndVariance(model, features);
            float score = predicted.getScore();

            if(label.equals(actual_label)) {
                correctScore = score;
                correctVariance = predicted.getVariance();
            } else {
                if(maxAnotherLabel == null || score > maxAnotherScore) {
                    maxAnotherLabel = label;
                    maxAnotherScore = score;
                    maxAnotherVariance = predicted.getVariance();
                }
            }
        }

        float var = correctVariance + maxAnotherVariance;
View Full Code Here

                score += (old_w.get() * v);
                variance += (old_w.getCovariance() * v * v);
            }
        }

        return new PredictionResult(score).variance(variance);
    }
View Full Code Here

                score += (old_w * v);
            }
            squared_norm += (v * v);
        }

        return new PredictionResult(score).squaredNorm(squared_norm);
    }
View Full Code Here

                score += (old_w.get() * v);
                variance += (old_w.getCovariance() * v * v);
            }
        }

        return new PredictionResult(score).variance(variance);
    }
View Full Code Here

    @Override
    protected void train(List<?> features, int label) {
        final float y = label > 0 ? 1.f : -1.f;

        PredictionResult margin = calcScoreAndVariance(features);
        float m = margin.getScore() * y;

        if(m < 1.f) {
            float var = margin.getVariance();
            float beta = 1.f / (var + r);
            float alpha = (1.f - m) * beta;
            update(features, y, alpha, beta);
        }
    }
View Full Code Here

        @Override
        protected void train(List<?> features, int label) {
            final float y = label > 0 ? 1.f : -1.f;

            PredictionResult margin = calcScoreAndVariance(features);
            float p = margin.getScore();
            float loss = loss(p, y); // C - m (m = y * p)

            if(loss > 0.f) {// m < 1.0 || 1.0 - m > 0
                float var = margin.getVariance();
                float beta = 1.f / (var + r);
                float alpha = loss * beta; // (1.f - m) * beta
                update(features, y, alpha, beta);
            }
        }
View Full Code Here

    @Override
    protected void train(final List<?> features, final int label) {
        final float y = label > 0 ? 1f : -1f;

        PredictionResult margin = calcScoreAndNorm(features);
        float p = margin.getScore();
        float loss = LossFunctions.hingeLoss(p, y); // 1.0 - y * p

        if(loss > 0.f) { // y * p < 1
            float eta = eta(loss, margin);
            float coeff = eta * y;
 
View Full Code Here

    @Override
    protected void train(List<?> features, int label) {
        final float y = label > 0 ? 1f : -1f;

        PredictionResult margin = calcScoreAndVariance(features);
        float loss = loss(margin, y);

        if(loss > 0.f) {
            float alpha = getAlpha(margin);
            if(alpha == 0.f) {
View Full Code Here

TOP

Related Classes of hivemall.io.PredictionResult

Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., and is owned by Oracle Inc. Contact coftware#gmail.com.