/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.sgd.linear;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.LinkedHashMap;
import java.util.Map;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.ONNXExportable;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.common.sgd.AbstractLinearSGDModel;
import org.tribuo.common.sgd.AbstractSGDModel;
import org.tribuo.math.LinearParameters;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.util.VectorNormalizer;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.onnx.ONNXNode;

public class LinearSGDModel
extends AbstractLinearSGDModel<Label>
implements ONNXExportable {
    private static final long serialVersionUID = 2L;
    private final VectorNormalizer normalizer;
    @Deprecated
    private DenseMatrix weights = null;

    LinearSGDModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> outputIDInfo, LinearParameters parameters, VectorNormalizer normalizer, boolean generatesProbabilities) {
        super(name, provenance, featureIDMap, outputIDInfo, parameters, generatesProbabilities);
        this.normalizer = normalizer;
    }

    public Prediction<Label> predict(Example<Label> example) {
        AbstractSGDModel.PredAndActive predTuple = this.predictSingle(example);
        DenseVector prediction = predTuple.prediction;
        prediction.normalize(this.normalizer);
        double maxScore = Double.NEGATIVE_INFINITY;
        Label maxLabel = null;
        LinkedHashMap<String, Label> predMap = new LinkedHashMap<String, Label>();
        for (int i = 0; i < prediction.size(); ++i) {
            String labelName = ((Label)this.outputIDInfo.getOutput(i)).getLabel();
            double score = prediction.get(i);
            Label label = new Label(labelName, score);
            predMap.put(labelName, label);
            if (!(score > maxScore)) continue;
            maxScore = score;
            maxLabel = label;
        }
        return new Prediction(maxLabel, predMap, predTuple.numActiveFeatures - 1, example, this.generatesProbabilities);
    }

    protected LinearSGDModel copy(String newName, ModelProvenance newProvenance) {
        return new LinearSGDModel(newName, newProvenance, this.featureIDMap, (ImmutableOutputInfo<Label>)this.outputIDInfo, (LinearParameters)this.modelParameters.copy(), this.normalizer, this.generatesProbabilities);
    }

    protected String getDimensionName(int index) {
        return ((Label)this.outputIDInfo.getOutput(index)).getLabel();
    }

    protected ONNXNode onnxOutput(ONNXNode input) {
        return this.normalizer.exportNormalizer(input);
    }

    protected String onnxModelName() {
        return "Classification-LinearSGDModel";
    }

    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        if (this.weights != null && this.modelParameters == null) {
            this.modelParameters = new LinearParameters(this.weights);
            this.weights = null;
            this.addBias = true;
        }
    }
}

