
Developing an Elasticsearch plugin: computing the similarity of feature vectors

交互式爱情

Customizing the Elasticsearch score

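In some cases we need to customize the score in order to personalize search results. For example, machine learning can give us a feature vector for each user and a feature vector for each product. How do we compute the similarity between these two vectors? The more similar the two vectors are, the higher the score, so that the products most similar to the user are recommended to that user first.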
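As a minimal illustration of the idea, the sketch below ranks products by cosine similarity to a user vector, cos(u, v) = u·v / (|u|·|v|). It is plain Java with no Elasticsearch dependency. The class name CosineRanking and the sample vectors are made up for illustration and are not part of the plugin.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Illustrative only: rank products by cosine similarity to a user vector.
// The vectors below are invented sample data, not real model output.
class CosineRanking {

    // cos(u, v) = (u . v) / (|u| * |v|); returns 0 when either vector has zero length
    static double cosine(double[] u, double[] v) {
        if (u.length != v.length) {
            throw new IllegalArgumentException("Vectors must have the same dimension");
        }
        double dot = 0, normU = 0, normV = 0;
        for (int i = 0; i < u.length; i++) {
            dot += u[i] * v[i];
            normU += u[i] * u[i];
            normV += v[i] * v[i];
        }
        if (normU == 0 || normV == 0) {
            return 0;
        }
        return dot / (Math.sqrt(normU) * Math.sqrt(normV));
    }

    // Returns product IDs ordered from most to least similar to the user
    static List<String> rank(double[] user, Map<String, double[]> products) {
        List<String> ids = new ArrayList<>(products.keySet());
        ids.sort(Comparator.comparingDouble((String id) -> cosine(user, products.get(id))).reversed());
        return ids;
    }

    public static void main(String[] args) {
        double[] user = {1.0, 0.0, 1.0};
        Map<String, double[]> products = new LinkedHashMap<>();
        products.put("p1", new double[]{0.0, 1.0, 0.0}); // orthogonal to the user: 0
        products.put("p2", new double[]{1.0, 0.0, 1.0}); // same direction: about 1
        products.put("p3", new double[]{1.0, 1.0, 0.0}); // partially aligned: about 0.5
        System.out.println(rank(user, products)); // [p2, p3, p1]
    }
}
```

In the plugin, the same computation runs once per matching document, and its result replaces the relevance score.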

Walking through the plugin source

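According to the official documentation, a script has to be run through a "ScriptEngine". To develop a custom plugin, we implement the "ScriptEngine" interface. The plugin class then returns that engine from its getScriptEngine() method, which is how Elasticsearch loads it. The ScriptEngine interface is described in [1]. Here is a concrete example from the official documentation: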

    private static class MyExpertScriptEngine implements ScriptEngine {
        // The name returned here is how the script API refers to this script backend (the script "lang")
        @Override
        public String getType() {
            return "expert_scripts";
        }

        // Core method; the factories below are implemented with Java lambda expressions
        @Override
        public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> context, Map<String, String> params) {
            if (context.equals(SearchScript.CONTEXT) == false) {
                throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]");
            }
            // we use the script "source" as the script identifier
            if ("pure_df".equals(scriptSource)) {
                // p holds the script params; lookup retrieves values from the document
                SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() {
                    final String field;
                    final String term;
                    {
                        if (p.containsKey("field") == false) {
                            throw new IllegalArgumentException("Missing parameter [field]");
                        }
                        if (p.containsKey("term") == false) {
                            throw new IllegalArgumentException("Missing parameter [term]");
                        }
                        field = p.get("field").toString();
                        term = p.get("term").toString();
                    }

                    @Override
                    public SearchScript newInstance(LeafReaderContext context) throws IOException {
                        PostingsEnum postings = context.reader().postings(new Term(field, term));
                        if (postings == null) {
                            // the field and/or term don't exist in this segment, so always return 0
                            return new SearchScript(p, lookup, context) {
                                @Override
                                public double runAsDouble() {
                                    return 0.0d;
                                }
                            };
                        }
                        return new SearchScript(p, lookup, context) {
                            int currentDocid = -1;

                            @Override
                            public void setDocument(int docid) {
                                // advance has undefined behavior calling with a docid <= its current docid
                                if (postings.docID() < docid) {
                                    try {
                                        postings.advance(docid);
                                    } catch (IOException e) {
                                        throw new UncheckedIOException(e);
                                    }
                                }
                                currentDocid = docid;
                            }

                            @Override
                            public double runAsDouble() {
                                if (postings.docID() != currentDocid) {
                                    // advance moved past the current doc, so this doc has no occurrences of the term
                                    return 0.0d;
                                }
                                try {
                                    return postings.freq();
                                } catch (IOException e) {
                                    throw new UncheckedIOException(e);
                                }
                            }
                        };
                    }

                    @Override
                    public boolean needs_score() {
                        return false;
                    }
                };
                return context.factoryClazz.cast(factory);
            }
            throw new IllegalArgumentException("Unknown script name " + scriptSource);
        }

        @Override
        public void close() {
            // optionally close resources
        }
    }
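The subtle part of this example is that Lucene postings only move forward. setDocument() advances the iterator only when it is behind the requested document. runAsDouble() returns 0 when advance() overshot, meaning the current document does not contain the term.

The plain-Java model below sketches that behavior. ForwardPostings and PureDfScript are hypothetical stand-ins written for illustration, not Lucene or Elasticsearch classes. The model assumes documents are visited in increasing doc ID order, as in a segment scan.

```java
// Hypothetical stand-in for Lucene's PostingsEnum: docID() starts at -1 and advance() only moves forward.
class ForwardPostings {
    static final int NO_MORE_DOCS = Integer.MAX_VALUE;
    private final int[] docs;  // sorted doc IDs that contain the term
    private final int[] freqs; // term frequency in each of those docs
    private int index = -1;

    ForwardPostings(int[] docs, int[] freqs) {
        this.docs = docs;
        this.freqs = freqs;
    }

    int docID() {
        if (index < 0) {
            return -1;
        }
        return index < docs.length ? docs[index] : NO_MORE_DOCS;
    }

    // Moves to the first doc >= target (or NO_MORE_DOCS), never backwards
    int advance(int target) {
        do {
            index++;
        } while (index < docs.length && docs[index] < target);
        return docID();
    }

    int freq() {
        return freqs[index];
    }
}

// Hypothetical model of the "pure_df" per-document logic
class PureDfScript {
    private final ForwardPostings postings;
    private int currentDocid = -1;

    PureDfScript(ForwardPostings postings) {
        this.postings = postings;
    }

    void setDocument(int docid) {
        // Only advance when the iterator is behind the requested document
        if (postings.docID() < docid) {
            postings.advance(docid);
        }
        currentDocid = docid;
    }

    double runAsDouble() {
        // If advance() moved past the current doc, the doc does not contain the term
        return postings.docID() == currentDocid ? postings.freq() : 0.0;
    }

    public static void main(String[] args) {
        PureDfScript script = new PureDfScript(new ForwardPostings(new int[]{1, 4, 7}, new int[]{2, 5, 1}));
        for (int doc = 0; doc < 9; doc++) {
            script.setDocument(doc);
            System.out.println(doc + " -> " + script.runAsDouble());
        }
    }
}
```

With the term in docs 1, 4 and 7 (frequencies 2, 5 and 1), the loop prints 2.0, 5.0 and 1.0 for those docs and 0.0 for every other doc.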

Based on the code above and our business requirements, we wrote the following scripts:

Script 1

    package com;

    import org.apache.logging.log4j.LogManager;
    import org.apache.logging.log4j.Logger;
    import org.apache.lucene.index.LeafReaderContext;
    import org.elasticsearch.script.ScriptContext;
    import org.elasticsearch.script.ScriptEngine;
    import org.elasticsearch.script.SearchScript;
    import java.io.IOException;
    import java.util.Map;

    /**
     * Created with IntelliJ IDEA. User: 0.0. Date: 18-8-9, 2:32 PM.
     * Description: to personalize search results, we compute the similarity between the user's feature vector
     * and each product's feature vector. The higher the similarity, the higher the final score and the higher the rank.
     */
    public class FeatureVectorScoreSearchScript implements ScriptEngine {
        private static final Logger logger = LogManager.getLogger(FeatureVectorScoreSearchScript.class);

        @Override
        public String getType() {
            return "feature_vector_scoring_script";
        }
        @Override
        public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> context, Map<String, String> params) {
            logger.info("The feature_vector_scoring_script is calculating the similarity of users and commodities");
            if (!context.equals(SearchScript.CONTEXT)) {
                throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]");
            }
            if ("whb_fvs".equals(scriptSource)) {
                SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() {
                    final Map<String, Object> inputFeatureVector;
                    final String field;
                    {   // Validate the input parameters
                        if (p.containsKey("field") == false) {
                            throw new IllegalArgumentException("Missing parameter [field]");
                        }
                        if (p.containsKey("inputFeatureVector") == false) {
                            throw new IllegalArgumentException("Missing parameter [inputFeatureVector]");
                        }
                        field = p.get("field").toString();
                        inputFeatureVector = (Map<String, Object>) p.get("inputFeatureVector");
                    }

                    @Override
                    public SearchScript newInstance(LeafReaderContext context) throws IOException {
                        return new SearchScript(p, lookup, context) {
                            @Override
                            public double runAsDouble() {
                                // The product vector is read from _source as a {dimension: weight} object
                                final Object productFeatureVector = lookup.source().get(field);
                                if (productFeatureVector instanceof Map) {
                                    return calculateVectorSimilarity(inputFeatureVector, (Map<String, Object>) productFeatureVector);
                                }
                                logger.debug("The field [" + field + "] does not exist in the product");
                                return 0.0D;
                            }
                        };
                    }
                    @Override
                    public boolean needs_score() {
                        return false;
                    }
                };
                return context.factoryClazz.cast(factory);
            }
            throw new IllegalArgumentException("Unknown script name " + scriptSource);
        }
        @Override
        public void close() {}

        // Cosine similarity of two sparse vectors; a dimension missing on one side counts as 0.
        // Values are parsed via toString() because JSON numbers may arrive as Integer, Long or Double.
        public double calculateVectorSimilarity(Map<String, Object> inputFeatureVector, Map<String, Object> productFeatureVector) {
            if (inputFeatureVector == null || productFeatureVector == null) {
                return 0.0D;
            }
            double dotProduct = 0.0D, sumOfUser = 0.0D, sumOfProduct = 0.0D;
            for (Map.Entry<String, Object> entry : inputFeatureVector.entrySet()) {
                double dimScore = Double.parseDouble(entry.getValue().toString());
                sumOfUser += dimScore * dimScore;
                Object itemDimScore = productFeatureVector.get(entry.getKey());
                if (itemDimScore != null) {
                    dotProduct += dimScore * Double.parseDouble(itemDimScore.toString());
                }
            }
            for (Object itemDimScore : productFeatureVector.values()) {
                double value = Double.parseDouble(itemDimScore.toString());
                sumOfProduct += value * value;
            }
            if (sumOfUser == 0.0D || sumOfProduct == 0.0D) {
                return 0.0D;
            }
            return dotProduct / (Math.sqrt(sumOfUser) * Math.sqrt(sumOfProduct));
        }
    }
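The similarity function is the part of Script 1 most worth testing, and it can be exercised outside Elasticsearch. The sketch below is a standalone sparse cosine over {dimension: weight} maps. It assumes that a dimension missing on one side counts as 0, and that numbers from JSON may arrive as Integer, Double or strings. The class name SparseCosine is hypothetical.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical standalone sparse cosine similarity, runnable without Elasticsearch.
// Vectors are {dimension: weight} maps as they might come from JSON (_source or params).
class SparseCosine {

    // JSON numbers may be parsed as Integer, Long or Double; strings are parsed as a fallback
    static double toDouble(Object value) {
        if (value instanceof Number) {
            return ((Number) value).doubleValue();
        }
        return Double.parseDouble(value.toString());
    }

    // Cosine over the union of dimensions; a dimension missing on one side counts as 0
    static double cosine(Map<String, ?> user, Map<String, ?> product) {
        if (user == null || product == null) {
            return 0.0;
        }
        double dot = 0.0, userNormSq = 0.0, productNormSq = 0.0;
        for (Map.Entry<String, ?> entry : user.entrySet()) {
            double u = toDouble(entry.getValue());
            userNormSq += u * u;
            Object p = product.get(entry.getKey());
            if (p != null) {
                dot += u * toDouble(p);
            }
        }
        for (Object p : product.values()) {
            double v = toDouble(p);
            productNormSq += v * v;
        }
        if (userNormSq == 0.0 || productNormSq == 0.0) {
            return 0.0;
        }
        return dot / (Math.sqrt(userNormSq) * Math.sqrt(productNormSq));
    }

    public static void main(String[] args) {
        Map<String, Object> user = new HashMap<>();
        user.put("a", 1.0);
        user.put("b", 1);      // an Integer, as a JSON parser may produce
        Map<String, Object> product = new HashMap<>();
        product.put("a", 1.0);
        product.put("c", 1.0); // dimension "b" is missing from the product
        System.out.println(cosine(user, product)); // about 0.5
    }
}
```

Worked example: with user {a: 1, b: 1} and product {a: 1, c: 1}, the dot product is 1 and both norms are √2, so the similarity is 0.5. A loop that unboxes productFeatureVector.get(dimName) directly into a double would throw a NullPointerException on dimension b instead.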

Script 2 (fast-vector-distance)

    package com;

    import org.apache.logging.log4j.LogManager;
    import org.apache.logging.log4j.Logger;
    import org.apache.lucene.index.BinaryDocValues;
    import org.apache.lucene.index.LeafReaderContext;
    import org.apache.lucene.store.ByteArrayDataInput;
    import org.apache.lucene.util.BytesRef;
    import org.elasticsearch.common.settings.Settings;
    import org.elasticsearch.plugins.Plugin;
    import org.elasticsearch.plugins.ScriptPlugin;
    import org.elasticsearch.script.ScriptContext;
    import org.elasticsearch.script.ScriptEngine;
    import org.elasticsearch.script.SearchScript;

    import java.io.IOException;
    import java.nio.ByteBuffer;
    import java.nio.DoubleBuffer;
    import java.util.Collection;
    import java.util.List;
    import java.util.Map;

    /**
     * Created with IntelliJ IDEA.
     * User: 王火斌
     * Date: 18-8-9
     * Time: 2:32 PM
     * Description: to personalize search results, we compute the similarity between the user's feature vector
     * and each product's feature vector. The higher the similarity, the higher the final score and the higher the rank.
     *
     * This class is instantiated when Elasticsearch loads the plugin for the first time.
     * If you rename it, update the classname entry in plugin-descriptor.properties, which points to this class.
     */
    public final class FastVectorDistance extends Plugin implements ScriptPlugin {

        @Override
        public ScriptEngine getScriptEngine(Settings settings, Collection<ScriptContext<?>> contexts) {
            return new FastVectorDistanceEngine();
        }

        private static class FastVectorDistanceEngine implements ScriptEngine {
            private static final Logger logger = LogManager.getLogger(FastVectorDistance.class);
            private static final int DOUBLE_SIZE = 8;

            @Override
            public String getType() {
                return "feature_vector_scoring_script";
            }

            @Override
            public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> context, Map<String, String> params) {
                logger.info("The feature_vector_scoring_script is calculating the similarity of users and commodities");
                if (!context.equals(SearchScript.CONTEXT)) {
                    throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]");
                }
                if ("whb_fvd".equals(scriptSource)) {
                    SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() {
                        // The field to compare against
                        final String field;
                        // Whether this search should use cosine similarity (true) or the raw dot product (false)
                        final boolean cosine;
                        // Parsed from the params but not used by the scoring below
                        final boolean exclude;
                        // The query vector
                        final double[] inputVector;
                        // Squared norm of the query vector, kept per query instead of on the shared engine instance
                        double queryVectorNorm;

                        {
                            if (p.containsKey("field") == false) {
                                throw new IllegalArgumentException("Missing parameter [field]");
                            }
                            // Get the field name from the query
                            field = p.get("field").toString();
                            // Determine whether to use cosine
                            final Object cosineBool = p.get("cosine");
                            cosine = cosineBool != null ? (boolean) cosineBool : true;
                            final Object excludeBool = p.get("exclude");
                            exclude = excludeBool != null ? (boolean) excludeBool : true;
                            // Accept either a raw JSON array ("vector") or a base64 string ("encoded_vector")
                            final Object vector = p.get("vector");
                            if (vector != null) {
                                final List<?> tmp = (List<?>) vector;
                                inputVector = new double[tmp.size()];
                                for (int i = 0; i < inputVector.length; i++) {
                                    inputVector[i] = ((Number) tmp.get(i)).doubleValue();
                                }
                            } else {
                                final Object encodedVector = p.get("encoded_vector");
                                if (encodedVector == null) {
                                    throw new IllegalArgumentException("Must have 'vector' or 'encoded_vector' as a parameter");
                                }
                                // Util is a helper class from the project repository (not shown in this article)
                                inputVector = Util.convertBase64ToArray((String) encodedVector);
                            }
                            // If cosine, compute the query vector's squared norm once per query
                            if (cosine) {
                                queryVectorNorm = 0d;
                                for (double v : inputVector) {
                                    queryVectorNorm += v * v;
                                }
                            }
                        }

                        @Override
                        public SearchScript newInstance(LeafReaderContext context) throws IOException {
                            // Use the Lucene leaf reader to access binary doc values directly;
                            // this is null when no document in the segment has the field
                            final BinaryDocValues accessor = context.reader().getBinaryDocValues(field);
                            return new SearchScript(p, lookup, context) {
                                boolean hasValue = false;

                                @Override
                                public void setDocument(int docId) {
                                    // advanceExact must be called with non-decreasing doc IDs and
                                    // returns false when this document has no value for the field
                                    try {
                                        hasValue = accessor != null && accessor.advanceExact(docId);
                                    } catch (IOException e) {
                                        hasValue = false;
                                    }
                                }

                                @Override
                                public double runAsDouble() {
                                    // If there is no field value, return 0 rather than fail
                                    if (!hasValue) {
                                        return 0.0d;
                                    }
                                    final int inputVectorSize = inputVector.length;
                                    final BytesRef value;
                                    try {
                                        value = accessor.binaryValue();
                                    } catch (IOException e) {
                                        return 0d;
                                    }
                                    // Respect the BytesRef offset and length instead of reading from index 0
                                    final ByteArrayDataInput byteDocVector = new ByteArrayDataInput(value.bytes, value.offset, value.length);
                                    byteDocVector.readVInt(); // number of values stored for the field
                                    final int docVectorLength = byteDocVector.readVInt(); // number of bytes to read
                                    if (docVectorLength != inputVectorSize * DOUBLE_SIZE) {
                                        return 0d;
                                    }
                                    final int position = byteDocVector.getPosition();
                                    final DoubleBuffer doubleBuffer = ByteBuffer.wrap(value.bytes, position, docVectorLength).asDoubleBuffer();
                                    final double[] docVector = new double[inputVectorSize];
                                    doubleBuffer.get(docVector);
                                    double docVectorNorm = 0d;
                                    double score = 0d;
                                    // Dot product of the document vector and the query vector
                                    for (int i = 0; i < inputVectorSize; i++) {
                                        score += docVector[i] * inputVector[i];
                                        if (cosine) {
                                            docVectorNorm += docVector[i] * docVector[i];
                                        }
                                    }
                                    // If cosine, normalize the dot product into a cosine score
                                    if (cosine) {
                                        if (docVectorNorm == 0 || queryVectorNorm == 0) {
                                            return 0d;
                                        }
                                        score = score / (Math.sqrt(docVectorNorm) * Math.sqrt(queryVectorNorm));
                                    }
                                    return score;
                                }
                            };
                        }

                        @Override
                        public boolean needs_score() {
                            return false;
                        }
                    };
                    return context.factoryClazz.cast(factory);
                }
                throw new IllegalArgumentException("Unknown script name " + scriptSource);
            }

            @Override
            public void close() {}
        }
    }
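Script 2 reads the binary doc values directly, so it depends on their byte layout. Judging by how runAsDouble() parses them, each document's value has three parts:

- a VInt count of values,
- a VInt byte length,
- the vector as big-endian doubles (ByteBuffer's default order).

The sketch below builds and parses that layout in plain Java, so the parsing steps can be tested without a cluster. BinaryVectorLayout and its methods are hypothetical names. The VInt encoding follows Lucene's scheme of 7 bits per byte, low-order group first, with the high bit as a continuation flag.

```java
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;

// Hypothetical helper modelling the doc value layout that Script 2 parses:
// VInt(number of values), VInt(byte length), then big-endian doubles.
class BinaryVectorLayout {

    // Lucene-style VInt: 7 bits per byte, low-order group first, high bit = more bytes follow
    static void writeVInt(ByteArrayOutputStream out, int value) {
        while ((value & ~0x7F) != 0) {
            out.write((value & 0x7F) | 0x80);
            value >>>= 7;
        }
        out.write(value);
    }

    // Reads a VInt starting at pos[0] and advances pos[0] past it
    static int readVInt(byte[] bytes, int[] pos) {
        int value = 0;
        for (int shift = 0; ; shift += 7) {
            byte b = bytes[pos[0]++];
            value |= (b & 0x7F) << shift;
            if ((b & 0x80) == 0) {
                return value;
            }
        }
    }

    // Builds the doc value of a single-valued binary field holding one vector
    static byte[] encodeDocValue(double[] vector) {
        ByteBuffer buffer = ByteBuffer.allocate(vector.length * Double.BYTES); // big-endian by default
        buffer.asDoubleBuffer().put(vector);
        byte[] payload = buffer.array();
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        writeVInt(out, 1);              // one value in this document
        writeVInt(out, payload.length); // byte length of that value
        out.write(payload, 0, payload.length);
        return out.toByteArray();
    }

    // Mirrors runAsDouble(): skip the count, check the length, read the doubles, then score
    static double score(byte[] docValue, double[] query, boolean cosine) {
        int[] pos = {0};
        readVInt(docValue, pos);              // number of values (skipped, as in Script 2)
        int length = readVInt(docValue, pos); // byte length of the first value
        if (length != query.length * Double.BYTES) {
            return 0.0; // dimension mismatch
        }
        double[] doc = new double[query.length];
        ByteBuffer.wrap(docValue, pos[0], length).asDoubleBuffer().get(doc);
        double dot = 0.0, docNormSq = 0.0, queryNormSq = 0.0;
        for (int i = 0; i < query.length; i++) {
            dot += doc[i] * query[i];
            docNormSq += doc[i] * doc[i];
            queryNormSq += query[i] * query[i];
        }
        if (!cosine) {
            return dot;
        }
        if (docNormSq == 0.0 || queryNormSq == 0.0) {
            return 0.0;
        }
        return dot / (Math.sqrt(docNormSq) * Math.sqrt(queryNormSq));
    }

    public static void main(String[] args) {
        byte[] docValue = encodeDocValue(new double[]{1, 2, 2});
        System.out.println(score(docValue, new double[]{1, 2, 2}, false)); // 9.0
        System.out.println(score(docValue, new double[]{1, 2, 2}, true));  // 1.0
    }
}
```

For a 3-dimensional vector the doc value is 26 bytes: one byte for the count, one byte for the length 24, then 24 bytes of doubles.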

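Deployment

The plugin is built and packaged with Maven, in the following steps:

  1. Configure the pom.xml file
    Declare the dependencies and set the directory the build output is written to.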



    <?xml version="1.0" encoding="UTF-8"?>
    <project xmlns="http://maven.apache.org/POM/4.0.0"
             xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
             xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
        <modelVersion>4.0.0</modelVersion>
        <groupId>es-plugin</groupId>
        <artifactId>elasticsearch-plugin</artifactId>
        <version>1.0-SNAPSHOT</version>

        <dependencies>
            <dependency>
                <groupId>org.elasticsearch</groupId>
                <artifactId>elasticsearch</artifactId>
                <version>6.1.1</version>
            </dependency>
            <dependency>
                <groupId>junit</groupId>
                <artifactId>junit</artifactId>
                <version>4.12</version>
                <scope>test</scope>
            </dependency>
        </dependencies>
        <build>
            <plugins>
                <plugin>
                    <artifactId>maven-assembly-plugin</artifactId>
                    <version>2.3</version>
                    <configuration>
                        <appendAssemblyId>false</appendAssemblyId>
                        <outputDirectory>${project.build.directory}/releases/</outputDirectory>
                        <descriptors>
                            <descriptor>${basedir}/src/assembly/plugin.xml</descriptor>
                        </descriptors>
                    </configuration>
                    <executions>
                        <execution>
                            <phase>package</phase>
                            <goals>
                                <goal>single</goal>
                            </goals>
                        </execution>
                    </executions>
                </plugin>
                <plugin>
                    <groupId>org.apache.maven.plugins</groupId>
                    <artifactId>maven-compiler-plugin</artifactId>
                    <configuration>
                        <source>1.8</source>
                        <target>1.8</target>
                    </configuration>
                </plugin>
            </plugins>
        </build>
    </project>


2. Create the assembly descriptor src/assembly/plugin.xml

<?xml version="1.0"?>
<assembly>
    <id>plugin</id>
    <formats>
        <format>zip</format>
    </formats>
    <includeBaseDirectory>false</includeBaseDirectory>
    <fileSets>
        <fileSet>
            <directory>${project.basedir}/src/main/resources</directory>
            <outputDirectory>feature-vector-score</outputDirectory>
        </fileSet>
    </fileSets>
    <dependencySets>
        <dependencySet>
            <outputDirectory>feature-vector-score</outputDirectory>
            <useProjectArtifact>true</useProjectArtifact>
            <useTransitiveFiltering>true</useTransitiveFiltering>
            <excludes>
                <exclude>org.elasticsearch:elasticsearch</exclude>
                <exclude>org.apache.logging.log4j:log4j-api</exclude>
            </excludes>
        </dependencySet>
    </dependencySets>
</assembly>

3. Create the plugin-descriptor.properties file in src/main/resources (the assembly above copies that directory into the plugin zip)

description=feature-vector-similarity
version=1.0
name=feature-vector-score
site=${elasticsearch.plugin.site}
jvm=true
classname=com.FeatureVectorScoreSearchPlugin
java.version=1.8
elasticsearch.version=6.1.1
isolated=${elasticsearch.plugin.isolated}

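Meaning of the main fields:

description (String): a simple summary of the plugin.
version (String): the plugin's version.
name (String): the plugin name.
classname (String): the fully qualified name of the plugin class to load, i.e. the class that extends Plugin (for Script 2 above, that would be com.FastVectorDistance).
java.version (String): the version of Java the code is built against, as given by the system property java.specification.version. The version string must be a sequence of non-negative decimal integers separated by "." and may have leading zeros.
elasticsearch.version (String): the Elasticsearch version the plugin is built for; the plugin can only be installed on that exact version.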
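Before packaging, it can be worth checking the descriptor mechanically. The helper below is hypothetical and is not Elasticsearch's own validation. It loads the file with java.util.Properties, reports missing keys from the list above plus elasticsearch.version, and flags values that still contain ${...}. Such placeholders, like site and isolated above, are only substituted if filtering is enabled in the build. The pom and assembly descriptor shown here do not turn filtering on.

```java
import java.io.IOException;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;

// Hypothetical pre-packaging check for plugin-descriptor.properties (not Elasticsearch's own validator)
class DescriptorCheck {
    static final String[] REQUIRED_KEYS = {
        "description", "version", "name", "classname", "java.version", "elasticsearch.version"
    };

    static Properties load(String text) {
        Properties props = new Properties();
        try {
            props.load(new StringReader(text));
        } catch (IOException e) {
            // StringReader performs no real I/O, so this is not expected to happen
            throw new UncheckedIOException(e);
        }
        return props;
    }

    // Required keys that are absent or empty, in REQUIRED_KEYS order
    static List<String> missingKeys(Properties props) {
        List<String> missing = new ArrayList<>();
        for (String key : REQUIRED_KEYS) {
            String value = props.getProperty(key);
            if (value == null || value.trim().isEmpty()) {
                missing.add(key);
            }
        }
        return missing;
    }

    // Keys whose values still contain an unresolved ${...} placeholder, sorted by name
    static List<String> unresolvedPlaceholders(Properties props) {
        List<String> keys = new ArrayList<>();
        for (String key : props.stringPropertyNames()) {
            if (props.getProperty(key).contains("${")) {
                keys.add(key);
            }
        }
        keys.sort(null);
        return keys;
    }

    public static void main(String[] args) {
        Properties props = load(String.join("\n",
            "description=feature-vector-similarity",
            "version=1.0",
            "name=feature-vector-score",
            "site=${elasticsearch.plugin.site}",
            "jvm=true",
            "classname=com.FeatureVectorScoreSearchPlugin",
            "java.version=1.8",
            "elasticsearch.version=6.1.1",
            "isolated=${elasticsearch.plugin.isolated}"));
        System.out.println("missing: " + missingKeys(props));               // missing: []
        System.out.println("unresolved: " + unresolvedPlaceholders(props)); // unresolved: [isolated, site]
    }
}
```

Run against the descriptor above, the check finds no missing keys but reports site and isolated as unresolved.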

Testing

Create the index

create_index = {
    "settings": {
        "analysis": {
            "analyzer": {
                # this configures the custom analyzer we need to parse vectors such that the scoring
                # plugin will work correctly
                "payload_analyzer": {
                    "type": "custom",
                    "tokenizer": "whitespace",
                    "filter": "delimited_payload_filter"
                }
            }
        }
    },
    "mappings": {
        "movies": {
            # this mapping definition sets up the metadata fields for the movies
            "properties": {
                "movieId": {"type": "integer"},
                "tmdbId": {"type": "keyword"},
                "genres": {"type": "keyword"},
                "release_date": {"type": "date", "format": "year"},
                "@model": {
                    # this mapping definition sets up the fields for movie factor vectors of our model
                    "properties": {
                        "factor": {"type": "binary", "doc_values": True},
                        "version": {"type": "keyword"},
                        "timestamp": {"type": "date"}
                    }
                }
            }
        }
    }
}
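The factor field is mapped as binary, so each movie's vector has to be sent as a base64 string when indexing. Script 2 reads the stored bytes as big-endian doubles. The sketch below (hypothetical class VectorBase64, plain Java) encodes a double[] that way and decodes it back.

The same string form is what the query's encoded_vector parameter carries. That assumes the project's Util.convertBase64ToArray decodes the same big-endian layout; the helper is not shown in this article.

```java
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Base64;

// Hypothetical helper: convert a factor vector to/from the base64 string of its big-endian doubles
class VectorBase64 {

    static String encode(double[] vector) {
        ByteBuffer buffer = ByteBuffer.allocate(vector.length * Double.BYTES); // big-endian by default
        buffer.asDoubleBuffer().put(vector);
        return Base64.getEncoder().encodeToString(buffer.array());
    }

    static double[] decode(String base64) {
        byte[] bytes = Base64.getDecoder().decode(base64);
        if (bytes.length % Double.BYTES != 0) {
            throw new IllegalArgumentException("Byte length is not a multiple of 8: " + bytes.length);
        }
        double[] vector = new double[bytes.length / Double.BYTES];
        ByteBuffer.wrap(bytes).asDoubleBuffer().get(vector);
        return vector;
    }

    public static void main(String[] args) {
        double[] factor = {0.25, -1.5, 3.0};
        String encoded = encode(factor);
        System.out.println(encoded);
        System.out.println(Arrays.toString(decode(encoded))); // [0.25, -1.5, 3.0]
    }
}
```

For example, the one-element vector [1.0] encodes to P/AAAAAAAAA=. That string would be the value of @model.factor in the indexed document.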

Query

Execute the script by setting lang to the engine type returned by getType() (feature_vector_scoring_script) and source to the script name (whb_fvd):

{
  "query": {
    "function_score": {
      "query": {
        "match_all": {}
      },
      "functions": [
        {
          "script_score": {
            "script": {
              "source": "whb_fvd",
              "lang": "feature_vector_scoring_script",
              "params": {
                "field": "@model.factor",
                "cosine": true,
                "encoded_vector": "v9EUmGAAAAC/6f9VAAAAAL/j+OOgAAAAv+m6+oAAAAA/lTSDIAAAAL/FdkTAAAAAv7rKHKAAAAA/0iyEYAAAAD/ZUY6gAAAAP7TzYoAAAAA/1K4IAAAAAD+yH9XgAAAAv6QRBSAAAAA/vRiiwAAAAL/mRhzgAAAAv9WxpiAAAAC/8YD+QAAAAL/jpbtgAAAAv+zmD+AAAAC/1eqtIAAAAA=="
              }
            }
          }
        }
      ]
    }
  }
}

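Version notes

Elasticsearch versions have been released at a fast pace over the past year. The plugins above mainly use the SearchScript class, which applies to v5.4 through v6.4. Versions below 5.4 mainly use the ExecutableScript class. Versions later than 6.4 introduce a new class, ScoreScript, for implementing custom scoring scripts.

See the full project on GitHub: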

https://github.com/SnailWhb/elasticsearch_pulgine_fast-vector-distance

References

[1] https://static.javadoc.io/org.elasticsearch/elasticsearch/6.0.1/org/elasticsearch/script/ScriptEngine.html
[2] https://www.elastic.co/guide/en/elasticsearch/reference/current/modules-scripting-engine.html
[3] https://github.com/jiashiwen/elasticsearchpluginsample
[4] https://www.elastic.co/guide/en/elasticsearch/plugins/6.3/plugin-authors.html

Author: 视野

Original source: https://www.cnblogs.com/whb-20160329/p/10472717.html
