Package org.apache.stanbol.enhancer.topic.training

Source Code of org.apache.stanbol.enhancer.topic.training.SolrTrainingSet

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.stanbol.enhancer.topic.training;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Date;
import java.util.Dictionary;
import java.util.List;
import java.util.UUID;

import org.apache.commons.lang.StringUtils;
import org.apache.felix.scr.annotations.Activate;
import org.apache.felix.scr.annotations.Component;
import org.apache.felix.scr.annotations.ConfigurationPolicy;
import org.apache.felix.scr.annotations.Deactivate;
import org.apache.felix.scr.annotations.Properties;
import org.apache.felix.scr.annotations.Property;
import org.apache.felix.scr.annotations.Reference;
import org.apache.felix.scr.annotations.ReferenceCardinality;
import org.apache.felix.scr.annotations.ReferencePolicy;
import org.apache.felix.scr.annotations.ReferenceStrategy;
import org.apache.felix.scr.annotations.Service;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.client.solrj.SolrServer;
import org.apache.solr.client.solrj.SolrServerException;
import org.apache.solr.client.solrj.response.QueryResponse;
import org.apache.solr.client.solrj.util.ClientUtils;
import org.apache.solr.common.SolrDocument;
import org.apache.solr.common.SolrInputDocument;
import org.apache.stanbol.commons.solr.managed.ManagedSolrServer;
import org.apache.stanbol.enhancer.topic.Batch;
import org.apache.stanbol.enhancer.topic.ConfiguredSolrCoreTracker;
import org.apache.stanbol.enhancer.topic.UTCTimeStamper;
import org.osgi.framework.InvalidSyntaxException;
import org.osgi.service.cm.ConfigurationException;
import org.osgi.service.component.ComponentContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Implementation of the {@code TrainingSet} interface that uses a Solr Core as backend to store and retrieve
* the text examples used to train a classifier.
*/
@Component(metatype = true, immediate = true, configurationFactory = true, policy = ConfigurationPolicy.REQUIRE)
@Service
@Properties(value = {@Property(name = SolrTrainingSet.TRAINING_SET_NAME),
                     @Property(name = SolrTrainingSet.SOLR_CORE),
                     @Property(name = SolrTrainingSet.EXAMPLE_ID_FIELD, value = "id"),
                     @Property(name = SolrTrainingSet.EXAMPLE_TEXT_FIELD, value = "text"),
                     @Property(name = SolrTrainingSet.TOPICS_URI_FIELD, value = "topics"),
                     @Property(name = SolrTrainingSet.MODIFICATION_DATE_FIELD, value = "modification_dt")})
public class SolrTrainingSet extends ConfiguredSolrCoreTracker implements TrainingSet {

    public static final String TRAINING_SET_NAME = "org.apache.stanbol.enhancer.topic.trainingset.id";

    public static final String SOLR_CORE = "org.apache.stanbol.enhancer.engine.topic.solrCore";

    public static final String TOPICS_URI_FIELD = "org.apache.stanbol.enhancer.engine.topic.topicsUriField";

    public static final String EXAMPLE_ID_FIELD = "org.apache.stanbol.enhancer.engine.topic.exampleIdField";

    public static final String EXAMPLE_TEXT_FIELD = "org.apache.stanbol.enhancer.engine.topic.exampleTextField";

    public static final String MODIFICATION_DATE_FIELD = "org.apache.stanbol.enhancer.engine.topic.modificiationDateField";

    @SuppressWarnings("unused")
    private static final Logger log = LoggerFactory.getLogger(SolrTrainingSet.class);

    protected String trainingSetId;

    protected String exampleIdField;

    protected String exampleTextField;

    protected String topicUrisField;

    protected String modificationDateField;

    // TODO: make me configurable using an OSGi property
    protected int batchSize = 100;

    @Reference(cardinality = ReferenceCardinality.OPTIONAL_UNARY, bind = "bindManagedSolrServer", unbind = "unbindManagedSolrServer", strategy = ReferenceStrategy.EVENT, policy = ReferencePolicy.DYNAMIC)
    protected ManagedSolrServer managedSolrServer;

    public String getName() {
        return trainingSetId;
    }

    @Activate
    protected void activate(ComponentContext context) throws ConfigurationException, InvalidSyntaxException {
        indexArchiveName = "default-topic-trainingset";
        @SuppressWarnings("unchecked")
        Dictionary<String,Object> config = context.getProperties();
        this.context = context;
        configure(config);
    }

    @Deactivate
    public void deactivate(ComponentContext context) {
        if (indexTracker != null) {
            indexTracker.close();
        }
    }

    @Override
    public void configure(Dictionary<String,Object> config) throws ConfigurationException {
        trainingSetId = getRequiredStringParam(config, TRAINING_SET_NAME);
        exampleIdField = getRequiredStringParam(config, EXAMPLE_ID_FIELD);
        exampleTextField = getRequiredStringParam(config, EXAMPLE_TEXT_FIELD);
        topicUrisField = getRequiredStringParam(config, TOPICS_URI_FIELD);
        modificationDateField = getRequiredStringParam(config, MODIFICATION_DATE_FIELD);
        configureSolrCore(config, SOLR_CORE, trainingSetId);
    }

    public static ConfiguredSolrCoreTracker fromParameters(Dictionary<String,Object> config) throws ConfigurationException {
        ConfiguredSolrCoreTracker engine = new SolrTrainingSet();
        engine.configure(config);
        return engine;
    }

    @Override
    public boolean isUpdatable() {
        return true;
    }

    @Override
    public String registerExample(String exampleId, String text, List<String> topics) throws TrainingSetException {
        if (text == null) {
            // special case: example removal
            if (exampleId == null) {
                throw new IllegalArgumentException("exampleId and text should not be null simultaneously");
            }
            SolrServer solrServer = getActiveSolrServer();
            try {
                solrServer.deleteByQuery(exampleIdField + ":" + exampleId);
                solrServer.commit();
                return exampleId;
            } catch (Exception e) {
                String msg = String.format("Error deleting example with id '%s' on Solr Core '%s'",
                    exampleId, solrCoreId);
                throw new TrainingSetException(msg, e);
            }
        }

        if (exampleId == null || exampleId.isEmpty()) {
            exampleId = UUID.randomUUID().toString();
        }
        SolrInputDocument doc = new SolrInputDocument();
        doc.addField(exampleIdField, exampleId);
        doc.addField(exampleTextField, text);
        if (topics != null) {
            doc.addField(topicUrisField, topics);
        }
        doc.addField(modificationDateField, UTCTimeStamper.nowUtcDate());
        SolrServer server = getActiveSolrServer();
        try {
            server.add(doc);
            server.commit();
        } catch (Exception e) {
            String msg = String.format("Could not register example '%s' with topics: ['%s']", exampleId,
                StringUtils.join(topics, "', '"));
            throw new TrainingSetException(msg, e);
        }
        return exampleId;
    }

    @Override
    public boolean hasChangedSince(List<String> topics, Date referenceDate) throws TrainingSetException {
        String utcIsoDate = UTCTimeStamper.utcIsoString(referenceDate);
        StringBuffer sb = new StringBuffer();
        sb.append(modificationDateField);
        sb.append(":[");
        sb.append(utcIsoDate);
        sb.append(" TO *]");
        if (topics != null && topics.size() > 0) {
            sb.append(" AND (");
            List<String> parts = new ArrayList<String>();
            for (String topic : topics) {
                // use a nested query to avoid string escaping issues with special solr chars
                parts.add(topicUrisField + ":" + ClientUtils.escapeQueryChars(topic));
            }
            sb.append(StringUtils.join(parts, " OR "));
            sb.append(")");
        }
        SolrQuery query = new SolrQuery(sb.toString());
        query.setRows(1);
        query.setFields(exampleIdField);
        try {
            SolrServer solrServer = getActiveSolrServer();
            return solrServer.query(query).getResults().size() > 0;
        } catch (SolrServerException e) {
            String msg = String.format(
                "Error while fetching topics for examples modified after '%s' on Solr Core '%s'.",
                utcIsoDate, solrCoreId);
            throw new TrainingSetException(msg, e);
        }
    }

    @Override
    public Batch<Example> getPositiveExamples(List<String> topics, Object offset) throws TrainingSetException {
        return getExamples(topics, offset, true);
    }

    @Override
    public Batch<Example> getNegativeExamples(List<String> topics, Object offset) throws TrainingSetException {
        return getExamples(topics, offset, false);
    }

    protected Batch<Example> getExamples(List<String> topics, Object offset, boolean positive) throws TrainingSetException {
        List<Example> items = new ArrayList<Example>();
        SolrServer solrServer = getActiveSolrServer();
        SolrQuery query = new SolrQuery();
        List<String> parts = new ArrayList<String>();
        String q = "";
        if (topics.isEmpty()) {
            q += "*:*";
        } else if (positive) {
            for (String topic : topics) {
                parts.add(topicUrisField + ":" + ClientUtils.escapeQueryChars(topic));
            }
            if (offset != null) {
                q += "(";
            }
            q += StringUtils.join(parts, " OR ");
            if (offset != null) {
                q += ")";
            }
        } else {
            for (String topic : topics) {
                parts.add("-" + topicUrisField + ":" + ClientUtils.escapeQueryChars(topic));
            }
            q += StringUtils.join(parts, " AND ");
        }
        if (offset != null) {
            q += " AND " + exampleIdField + ":[" + offset.toString() + " TO *]";
        }
        query.setQuery(q);
        query.addSortField(exampleIdField, SolrQuery.ORDER.asc);
        query.set("rows", batchSize + 1);
        String nextExampleId = null;
        try {
            int count = 0;
            QueryResponse response = solrServer.query(query);
            for (SolrDocument result : response.getResults()) {
                if (count == batchSize) {
                    nextExampleId = result.getFirstValue(exampleIdField).toString();
                } else {
                    count++;
                    String exampleId = result.getFirstValue(exampleIdField).toString();
                    Collection<Object> labelValues = result.getFieldValues(topicUrisField);
                    Collection<Object> textValues = result.getFieldValues(exampleTextField);
                    if (textValues == null) {
                        continue;
                    }
                    items.add(new Example(exampleId, labelValues, textValues));
                }
            }
        } catch (SolrServerException e) {
            String msg = String.format(
                "Error while fetching positive examples for topics ['%s'] on Solr Core '%s'.",
                StringUtils.join(topics, "', '"), solrCoreId);
            throw new TrainingSetException(msg, e);
        }
        return new Batch<Example>(items, nextExampleId != null, nextExampleId);
    }

    @Override
    public void setBatchSize(int batchSize) {
        this.batchSize = batchSize;
    }

    @Override
    public void optimize() throws TrainingSetException {
        try {
            getActiveSolrServer().optimize();
        } catch (Exception e) {
            throw new TrainingSetException("Error optimizing training dataset " + getName(), e);
        }
    }

}
TOP

Related Classes of org.apache.stanbol.enhancer.topic.training.SolrTrainingSet

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.