/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify;

import edu.cmu.minorthird.classify.BatchVersion;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.ClassifyCommandLineUtil;
import edu.cmu.minorthird.classify.DatasetClassifierTeacher;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.UI;
import edu.cmu.minorthird.classify.algorithms.knn.KnnLearner;
import edu.cmu.minorthird.classify.algorithms.linear.MaxEntLearner;
import edu.cmu.minorthird.classify.algorithms.linear.NaiveBayes;
import edu.cmu.minorthird.classify.algorithms.linear.VotedPerceptron;
import edu.cmu.minorthird.classify.algorithms.svm.SVMLearner;
import edu.cmu.minorthird.classify.algorithms.trees.AdaBoost;
import edu.cmu.minorthird.classify.algorithms.trees.DecisionTreeLearner;
import edu.cmu.minorthird.classify.experiments.CrossValSplitter;
import edu.cmu.minorthird.classify.experiments.RandomSplitter;
import edu.cmu.minorthird.classify.experiments.StratifiedCrossValSplitter;
import edu.cmu.minorthird.classify.multi.InstanceFromPrediction;
import edu.cmu.minorthird.classify.multi.MultiClassLabel;
import edu.cmu.minorthird.classify.multi.MultiClassifier;
import edu.cmu.minorthird.classify.multi.MultiDataset;
import edu.cmu.minorthird.classify.multi.MultiDatasetClassifierTeacher;
import edu.cmu.minorthird.classify.multi.MultiExample;
import edu.cmu.minorthird.classify.sequential.CollinsPerceptronLearner;
import edu.cmu.minorthird.classify.sequential.DatasetSequenceClassifierTeacher;
import edu.cmu.minorthird.classify.sequential.GenericCollinsLearner;
import edu.cmu.minorthird.classify.sequential.SequenceClassifier;
import edu.cmu.minorthird.classify.sequential.SequenceDataset;
import edu.cmu.minorthird.classify.transform.FrequencyBasedTransformLearner;
import edu.cmu.minorthird.classify.transform.InfoGainTransformLearner2;
import edu.cmu.minorthird.classify.transform.T1InstanceTransformLearner;
import edu.cmu.minorthird.classify.transform.TFIDFTransformLearner;
import edu.cmu.minorthird.classify.transform.TransformingBatchLearner;
import edu.cmu.minorthird.util.BasicCommandLineProcessor;
import edu.cmu.minorthird.util.CommandLineProcessor;
import edu.cmu.minorthird.util.IOUtil;
import edu.cmu.minorthird.util.JointCommandLineProcessor;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.StringUtil;
import edu.cmu.minorthird.util.Version;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.Console;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.TypeSelector;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.ViewerFrame;
import java.awt.Component;
import java.awt.GridBagConstraints;
import java.awt.GridBagLayout;
import java.awt.event.ActionEvent;
import java.io.ByteArrayOutputStream;
import java.io.PrintStream;
import java.util.Iterator;
import javax.swing.AbstractAction;
import javax.swing.JButton;
import javax.swing.JComponent;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JProgressBar;
import javax.swing.border.TitledBorder;
import org.apache.log4j.Logger;

/**
 * Command-line and GUI entry point for training a classifier on a dataset.
 *
 * <p>Depending on the experiment type held in the training parameters ("seq", "multi", or
 * anything else for a simple experiment), a sequence classifier, a multi-label classifier, or an
 * ordinary classifier is trained; the result can then be displayed and/or saved.
 */
public class Train {
    // NOTE(review): this logs under the UI category rather than Train; kept unchanged so that
    // existing log4j configuration for that category keeps applying.
    private static final Logger log = Logger.getLogger(UI.class);

    /** Types offered by the GUI's TypeSelector when editing the task parameters. */
    private static final Class<?>[] SELECTABLE_TYPES = new Class[]{
        DataClassificationTask.class,
        ClassifyCommandLineUtil.SimpleTrainParams.class,
        ClassifyCommandLineUtil.MultiTrainParams.class,
        ClassifyCommandLineUtil.SeqTrainParams.class,
        ClassifyCommandLineUtil.Learner.SequentialLearner.class,
        ClassifyCommandLineUtil.Learner.ClassifierLearner.class,
        KnnLearner.class, NaiveBayes.class, VotedPerceptron.class, SVMLearner.class,
        DecisionTreeLearner.class, AdaBoost.class, BatchVersion.class,
        TransformingBatchLearner.class, MaxEntLearner.class,
        FrequencyBasedTransformLearner.class, InfoGainTransformLearner2.class,
        T1InstanceTransformLearner.class, TFIDFTransformLearner.class,
        CollinsPerceptronLearner.class, GenericCollinsLearner.class,
        CrossValSplitter.class, RandomSplitter.class, StratifiedCrossValSplitter.class};

    /**
     * Parses {@code args} and either trains directly or, when {@code -gui} is given, opens the
     * parameter-editing GUI.
     *
     * @param args command-line arguments
     */
    public static void main(String[] args) {
        new DataClassificationTask().callMain(args);
    }

    /**
     * Configurable training task that can be run from the command line or from a
     * {@link Console} window.
     */
    public static class DataClassificationTask
    implements CommandLineProcessor.Configurable,
    Console.Task {
        /** Number of cross-validation folds used by {@link #annotateData(MultiDataset)}. */
        private static final int ANNOTATION_FOLDS = 9;

        private ClassifyCommandLineUtil.TrainParams trainParams = new ClassifyCommandLineUtil.TrainParams();
        /** Result of the most recent {@link #doMain()} run; returned by {@link #getMainResult()}. */
        public Object resultToShow;
        /** Set by the {@code -gui} option; when true, {@link #callMain} opens a GUI. */
        public boolean useGUI;
        /** Task handed to the GUI console; set to this task when the GUI is opened. */
        public Console.Task main;

        /** @return the current training parameters */
        public ClassifyCommandLineUtil.TrainParams getTrainParams() {
            return this.trainParams;
        }

        /** @param train the training parameters to use */
        public void setTrainParams(ClassifyCommandLineUtil.TrainParams train) {
            this.trainParams = train;
        }

        /** @return HTML help text describing the available experiment types */
        public String getTrainParamsHelp() {
            return "Define what type of experiment you would like to run: <br>Simple - Standard classify experiment <br> Multi  - Classify Experiment with Multiple labels per example <br>Seq    - Classify experiment with a Sequential Dataset, where each example has a history, <br>           and uses a Sequential Learner";
        }

        /** @return the training-data filename from the parameters (may be null) */
        public String getDatasetFilename() {
            return this.trainParams.trainDataFilename;
        }

        /**
         * Builds the command-line processor for this task: GUI options plus the training
         * parameters (each registered once).
         *
         * @return the joint command-line processor
         */
        public CommandLineProcessor getCLP() {
            // NOTE(review): GUIParams.gui() replaces this.trainParams after this array has been
            // built, so arguments that follow -gui are presumably still routed to the previous
            // TrainParams instance -- confirm against JointCommandLineProcessor.
            JointCommandLineProcessor jlpTrain = new JointCommandLineProcessor(new CommandLineProcessor[]{new GUIParams(), this.trainParams});
            return jlpTrain;
        }

        /** @return true when a training-data filename has been set */
        public boolean getLabels() {
            return this.getDatasetFilename() != null;
        }

        /**
         * Returns a new dataset in which every example of {@code md} is re-created with its
         * instance wrapped in an {@link InstanceFromPrediction} that carries the best class name
         * predicted by a classifier trained on the other folds of a
         * {@value #ANNOTATION_FOLDS}-fold cross-validation split. Multi-labels and weights are
         * copied unchanged.
         *
         * @param md the multi-label dataset to annotate
         * @return the annotated dataset
         */
        public MultiDataset annotateData(MultiDataset md) {
            MultiDataset annotatedDataset = new MultiDataset();
            CrossValSplitter<MultiExample> splitter = new CrossValSplitter<MultiExample>(ANNOTATION_FOLDS);
            MultiDataset.MultiSplit s = md.MultiSplit(splitter);
            for (int fold = 0; fold < ANNOTATION_FOLDS; ++fold) {
                // Train on every other fold, then annotate the held-out fold's examples.
                MultiDatasetClassifierTeacher teacher = new MultiDatasetClassifierTeacher(s.getTrain(fold));
                MultiClassifier c = teacher.train(this.trainParams.clsLnr.clsLearner);
                Iterator<MultiExample> i = s.getTest(fold).multiIterator();
                while (i.hasNext()) {
                    MultiExample ex = i.next();
                    Instance instance = ex.asInstance();
                    MultiClassLabel predicted = c.multiLabelClassification(instance);
                    InstanceFromPrediction annotatedInstance = new InstanceFromPrediction(instance, predicted.bestClassName());
                    MultiExample newEx = new MultiExample((Instance)annotatedInstance, ex.getMultiLabel(), ex.getWeight());
                    annotatedDataset.addMulti(newEx);
                }
            }
            return annotatedDataset;
        }

        /**
         * Trains a classifier according to the current parameters. The result is stored in
         * {@link #resultToShow}, saved once if a save target is set, and displayed if requested.
         * Prints a message and returns early if the training data is missing, or if it is not a
         * sequence dataset for a "seq" experiment.
         */
        public void doMain() {
            if (this.trainParams.trainData == null) {
                System.out.println("The training data needs to be specified with the -data option.");
                return;
            }
            // Constant-first equals so a null type string falls through to the simple case.
            if ("seq".equals(this.trainParams.typeString) && !(this.trainParams.trainData instanceof SequenceDataset)) {
                System.out.println("The training data should be a sequence dataset");
                return;
            }
            if (this.trainParams.showData) {
                new ViewerFrame("Training data", this.trainParams.trainData.toGUI());
            }
            if ("seq".equals(this.trainParams.typeString)) {
                DatasetSequenceClassifierTeacher teacher = new DatasetSequenceClassifierTeacher((SequenceDataset)this.trainParams.trainData);
                SequenceClassifier c = teacher.train(this.trainParams.seqLnr.seqLearner);
                this.trainParams.resultToShow = this.trainParams.resultToSave = c;
            } else if ("multi".equals(this.trainParams.typeString)) {
                MultiDataset multiData = this.trainParams.crossDim ? this.annotateData((MultiDataset)this.trainParams.trainData) : (MultiDataset)this.trainParams.trainData;
                MultiDatasetClassifierTeacher teacher = new MultiDatasetClassifierTeacher(multiData);
                MultiClassifier c = teacher.train(this.trainParams.clsLnr.clsLearner);
                this.trainParams.resultToShow = this.trainParams.resultToSave = c;
            } else {
                DatasetClassifierTeacher teacher = new DatasetClassifierTeacher(this.trainParams.trainData);
                Classifier c = teacher.train(this.trainParams.clsLnr.clsLearner);
                this.trainParams.resultToShow = this.trainParams.resultToSave = c;
            }
            this.resultToShow = this.trainParams.resultToShow;
            // Save exactly once (the result used to be written and logged twice).
            this.saveResult();
            if (this.trainParams.showResult) {
                new ViewerFrame("Result", new SmartVanillaViewer(this.trainParams.resultToShow));
            }
        }

        /** Saves the trained result to the configured target, if any, and logs the outcome. */
        private void saveResult() {
            if (this.trainParams.saveAs == null) {
                return;
            }
            if (IOUtil.saveSomehow(this.trainParams.resultToSave, this.trainParams.saveAs)) {
                log.info("Result saved in " + this.trainParams.saveAs);
            } else {
                log.error("Can't save " + this.trainParams.resultToSave.getClass() + " to " + this.trainParams.saveAs);
            }
        }

        /** @return the result of the most recent {@link #doMain()} run, or null */
        public Object getMainResult() {
            return this.resultToShow;
        }

        /**
         * Processes {@code args}, then either runs {@link #doMain()} directly or, if {@code -gui}
         * was given, opens a window for editing parameters and running the task. Exceptions are
         * printed along with a hint to use {@code -help}.
         *
         * @param args command-line arguments
         */
        public void callMain(final String[] args) {
            try {
                this.getCLP().processArguments(args);
                if (!this.useGUI) {
                    this.doMain();
                } else {
                    this.main = this;
                    ComponentViewer v = new ComponentViewer(){
                        static final long serialVersionUID = 20080128L;

                        public JComponent componentFor(Object o) {
                            TypeSelector ts = new TypeSelector(SELECTABLE_TYPES, "selectableTypes.txt", DataClassificationTask.class);
                            ts.setContent(o);
                            JPanel panel = new JPanel();
                            panel.setBorder(new TitledBorder(StringUtil.toString(args, "Command line: ", "", " ")));
                            panel.setLayout(new GridBagLayout());
                            JPanel subpanel1 = new JPanel();
                            subpanel1.setBorder(new TitledBorder("Parameter modification"));
                            subpanel1.add(ts);
                            GridBagConstraints gbc = Viewer.fillerGBC();
                            gbc.weighty = 0.0;
                            panel.add((Component)subpanel1, gbc);
                            JPanel subpanel2 = new JPanel();
                            subpanel2.setBorder(new TitledBorder("Execution controls"));
                            JButton viewButton = new JButton(new AbstractAction("View results"){
                                static final long serialVersionUID = 20080128L;

                                @Override
                                public void actionPerformed(ActionEvent event) {
                                    SmartVanillaViewer rv = new SmartVanillaViewer();
                                    rv.setContent(DataClassificationTask.this.getMainResult());
                                    new ViewerFrame("Result", rv);
                                }
                            });
                            viewButton.setEnabled(false);
                            JPanel errorPanel = new JPanel();
                            errorPanel.setBorder(new TitledBorder("Error messages and output"));
                            final Console console = new Console(DataClassificationTask.this.main, DataClassificationTask.this.getDatasetFilename() != null, viewButton);
                            errorPanel.add(console.getMainComponent());
                            JButton goButton = new JButton(new AbstractAction("Start task"){
                                static final long serialVersionUID = 20080128L;

                                @Override
                                public void actionPerformed(ActionEvent event) {
                                    console.start();
                                }
                            });
                            JButton showLabelsButton = new JButton(new AbstractAction("Show train data"){
                                static final long serialVersionUID = 20080128L;

                                @Override
                                public void actionPerformed(ActionEvent ev) {
                                    new ViewerFrame("Labeled TextBase", new SmartVanillaViewer(DataClassificationTask.this.trainParams.trainData));
                                }
                            });
                            JButton clearButton = new JButton(new AbstractAction("Clear window"){
                                static final long serialVersionUID = 20080128L;

                                @Override
                                public void actionPerformed(ActionEvent ev) {
                                    console.clear();
                                }
                            });
                            JButton helpParamsButton = new JButton(new AbstractAction("Parameters"){
                                static final long serialVersionUID = 20080128L;

                                /** Captures the command-line usage text and shows it in the console. */
                                @Override
                                public void actionPerformed(ActionEvent ev) {
                                    PrintStream oldSystemOut = System.out;
                                    ByteArrayOutputStream outBuffer = new ByteArrayOutputStream();
                                    PrintStream capture = new PrintStream(outBuffer);
                                    System.setOut(capture);
                                    try {
                                        // usage() prints to System.out, which is redirected here.
                                        DataClassificationTask.this.getCLP().usage();
                                        capture.flush();
                                    } finally {
                                        // Always restore stdout, even if usage() throws.
                                        System.setOut(oldSystemOut);
                                    }
                                    console.append(outBuffer.toString());
                                }
                            });
                            subpanel2.add(goButton);
                            subpanel2.add(viewButton);
                            subpanel2.add(showLabelsButton);
                            subpanel2.add(clearButton);
                            subpanel2.add(new JLabel("Help:"));
                            subpanel2.add(helpParamsButton);
                            gbc = Viewer.fillerGBC();
                            gbc.weighty = 0.0;
                            gbc.gridy = 1;
                            panel.add((Component)subpanel2, gbc);
                            gbc = Viewer.fillerGBC();
                            gbc.weighty = 1.0;
                            gbc.gridy = 2;
                            panel.add((Component)errorPanel, gbc);
                            JProgressBar progressBar1 = new JProgressBar();
                            JProgressBar progressBar2 = new JProgressBar();
                            JProgressBar progressBar3 = new JProgressBar();
                            ProgressCounter.setGraphicContext(new JProgressBar[]{progressBar1, progressBar2, progressBar3});
                            gbc = Viewer.fillerGBC();
                            gbc.weighty = 0.0;
                            gbc.gridy = 3;
                            panel.add((Component)progressBar1, gbc);
                            gbc = Viewer.fillerGBC();
                            gbc.weighty = 0.0;
                            gbc.gridy = 4;
                            panel.add((Component)progressBar2, gbc);
                            gbc = Viewer.fillerGBC();
                            gbc.weighty = 0.0;
                            gbc.gridy = 5;
                            panel.add((Component)progressBar3, gbc);
                            return panel;
                        }
                    };
                    v.setContent(this);
                    // getName() is the fully qualified name, i.e. toString() without "class ".
                    String className = this.getClass().getName();
                    new ViewerFrame(className + ": " + Version.getVersion(), v);
                }
            }
            catch (Exception e) {
                e.printStackTrace();
                System.out.println("Use option -help for help");
            }
        }

        /** Command-line options controlling presentation (currently only {@code -gui}). */
        protected class GUIParams
        extends BasicCommandLineProcessor {
            protected GUIParams() {
            }

            /**
             * Handles {@code -gui}: enables GUI mode and replaces the task's training parameters
             * with {@code TrainParams.type} if set, otherwise with new SimpleTrainParams.
             */
            public void gui() {
                DataClassificationTask.this.useGUI = true;
                if (ClassifyCommandLineUtil.TrainParams.type != null) {
                    DataClassificationTask.this.trainParams = ClassifyCommandLineUtil.TrainParams.type;
                } else {
                    DataClassificationTask.this.trainParams = new ClassifyCommandLineUtil.SimpleTrainParams();
                }
            }

            /** Prints help for the presentation options to System.out. */
            public void usage() {
                System.out.println("presentation parameters:");
                System.out.println(" -gui                     use graphic interface to set parameters");
                System.out.println();
            }
        }
    }
}

