Customizing the Elasticsearch score
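Sometimes we need a custom score so that search results can be personalized. For example, machine learning can give us a feature vector for every user and a feature vector for every product. How do we compute the similarity between these two vectors? The more similar the two vectors are, the higher the score should be, so that the products closest to the user's interests are recommended first.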
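For reference, the similarity computed by both scripts below (script 2 only when its cosine parameter is left at the default, true) is cosine similarity. For a user vector u and a product vector v over the same n dimensions:

```latex
\cos(\mathbf{u}, \mathbf{v}) = \frac{\sum_{i=1}^{n} u_i v_i}{\sqrt{\sum_{i=1}^{n} u_i^{2}}\,\sqrt{\sum_{i=1}^{n} v_i^{2}}}
```

It ranges from -1 to 1 and equals 1 when the two vectors point in the same direction. It is undefined when either vector has zero length, and in that case the scripts return 0.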
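Plugin source walkthrough
According to the official documentation, a script must be run through a "ScriptEngine". To develop a custom plugin, we implement the "ScriptEngine" interface and return an instance from the getScriptEngine() method of our plugin, which is how Elasticsearch loads it. See [1] for details of the ScriptEngine interface. Below is a concrete example from the official documentation [2]: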
private static class MyExpertScriptEngine implements ScriptEngine {
    // The name used as "lang" in the script API to refer to this script backend
    @Override
    public String getType() {
        return "expert_scripts";
    }

    // Core method; the factory below is implemented with a Java lambda expression
    @Override
    public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> context, Map<String, String> params) {
        if (context.equals(SearchScript.CONTEXT) == false) {
            throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]");
        }
        // we use the script "source" as the script identifier
        if ("pure_df".equals(scriptSource)) {
            // p holds the script params; lookup gives access to the document's values
            SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() {
                final String field;
                final String term;
                {
                    if (p.containsKey("field") == false) {
                        throw new IllegalArgumentException("Missing parameter [field]");
                    }
                    if (p.containsKey("term") == false) {
                        throw new IllegalArgumentException("Missing parameter [term]");
                    }
                    field = p.get("field").toString();
                    term = p.get("term").toString();
                }

                @Override
                public SearchScript newInstance(LeafReaderContext context) throws IOException {
                    PostingsEnum postings = context.reader().postings(new Term(field, term));
                    if (postings == null) {
                        // the field and/or term don't exist in this segment, so always return 0
                        return new SearchScript(p, lookup, context) {
                            @Override
                            public double runAsDouble() {
                                return 0.0d;
                            }
                        };
                    }
                    return new SearchScript(p, lookup, context) {
                        int currentDocid = -1;

                        @Override
                        public void setDocument(int docid) {
                            // advance has undefined behavior calling with a docid <= its current docid
                            if (postings.docID() < docid) {
                                try {
                                    postings.advance(docid);
                                } catch (IOException e) {
                                    throw new UncheckedIOException(e);
                                }
                            }
                            currentDocid = docid;
                        }

                        @Override
                        public double runAsDouble() {
                            if (postings.docID() != currentDocid) {
                                // advance moved past the current doc, so this doc has no occurrences of the term
                                return 0.0d;
                            }
                            try {
                                return postings.freq();
                            } catch (IOException e) {
                                throw new UncheckedIOException(e);
                            }
                        }
                    };
                }

                @Override
                public boolean needs_score() {
                    return false;
                }
            };
            return context.factoryClazz.cast(factory);
        }
        throw new IllegalArgumentException("Unknown script name " + scriptSource);
    }

    @Override
    public void close() {
        // optionally close resources
    }
}
Based on the code above and our business requirements, we wrote the following scripts:
Script 1
package com;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.LeafReaderContext;
import org.elasticsearch.script.ScriptContext;
import org.elasticsearch.script.ScriptEngine;
import org.elasticsearch.script.SearchScript;

import java.io.IOException;
import java.util.Map;

/**
 * Created with IntelliJ IDEA.
 * User: 0.0
 * Date: 18-8-9
 * Time: 2:32 PM
 * Description: to get personalized recommendations in search, we compute the similarity between
 * the user vector and each product's feature vector. The higher the similarity, the higher
 * the final score and the higher the product ranks.
 */
public class FeatureVectorScoreSearchScript implements ScriptEngine {
    private static final Logger logger = LogManager.getLogger(FeatureVectorScoreSearchScript.class);

    @Override
    public String getType() {
        return "feature_vector_scoring_script";
    }

    @Override
    public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> context, Map<String, String> params) {
        logger.info("The feature_vector_scoring_script is calculating the similarity of users and commodities");
        if (!context.equals(SearchScript.CONTEXT)) {
            throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]");
        }
        if ("whb_fvs".equals(scriptSource)) {
            SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() {
                // Validate the input parameters
                final Map<String, Object> inputFeatureVector;
                final String field;
                {
                    if (p.containsKey("field") == false) {
                        throw new IllegalArgumentException("Missing parameter [field]");
                    }
                    if (p.containsKey("inputFeatureVector") == false) {
                        throw new IllegalArgumentException("Missing parameter [inputFeatureVector]");
                    }
                    field = p.get("field").toString();
                    inputFeatureVector = (Map<String, Object>) p.get("inputFeatureVector");
                }

                @Override
                public SearchScript newInstance(LeafReaderContext context) throws IOException {
                    return new SearchScript(p, lookup, context) {
                        @Override
                        public double runAsDouble() {
                            // Read the product vector from _source; numbers may be parsed as Integer or Double
                            Object value = lookup.source().get(field);
                            if (value instanceof Map) {
                                Map<String, Object> productFeatureVector = (Map<String, Object>) value;
                                return calculateVectorSimilarity(inputFeatureVector, productFeatureVector);
                            }
                            logger.info("The field [" + field + "] does not exist in the product");
                            return 0.0D;
                        }
                    };
                }

                @Override
                public boolean needs_score() {
                    return false;
                }
            };
            return context.factoryClazz.cast(factory);
        }
        throw new IllegalArgumentException("Unknown script name " + scriptSource);
    }

    @Override
    public void close() {
    }

    // Cosine similarity of two sparse vectors stored as {dimension name -> value}
    public double calculateVectorSimilarity(Map<String, Object> inputFeatureVector, Map<String, Object> productFeatureVector) {
        if (inputFeatureVector == null || productFeatureVector == null) {
            return 0.0D;
        }
        double dotProduct = 0.0D;
        double userNormSquared = 0.0D;
        double productNormSquared = 0.0D;
        for (Map.Entry<String, Object> entry : inputFeatureVector.entrySet()) {
            double userValue = Double.parseDouble(entry.getValue().toString());
            userNormSquared += userValue * userValue;
            // A dimension missing from the product contributes 0 (instead of a NullPointerException)
            Object productValue = productFeatureVector.get(entry.getKey());
            if (productValue != null) {
                dotProduct += userValue * Double.parseDouble(productValue.toString());
            }
        }
        // The product norm covers all of the product's dimensions, as cosine similarity requires
        for (Object productValue : productFeatureVector.values()) {
            double v = Double.parseDouble(productValue.toString());
            productNormSquared += v * v;
        }
        if (userNormSquared == 0.0D || productNormSquared == 0.0D) {
            return 0.0D;
        }
        return dotProduct / (Math.sqrt(userNormSquared) * Math.sqrt(productNormSquared));
    }
}
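The similarity function in script 1 needs nothing from Elasticsearch, so it can be tested on its own. Below is a minimal standalone sketch of a sparse cosine similarity over {dimension -> value} maps; the class name SparseCosine is hypothetical. It treats a dimension missing from the product as 0 rather than unboxing a null, and it accepts Integer as well as Double values, on the assumption that JSON numbers from params or _source may arrive as either type.

```java
import java.util.Map;

// Hypothetical standalone sparse cosine similarity over {dimension name -> value} maps.
class SparseCosine {

    static double similarity(Map<String, ?> user, Map<String, ?> product) {
        if (user == null || product == null) {
            return 0.0;
        }
        double dot = 0.0;
        double userNormSq = 0.0;
        for (Map.Entry<String, ?> e : user.entrySet()) {
            double u = ((Number) e.getValue()).doubleValue();
            userNormSq += u * u;
            Object p = product.get(e.getKey());
            if (p != null) { // a dimension the product lacks contributes 0
                dot += u * ((Number) p).doubleValue();
            }
        }
        // The product norm is taken over all of the product's dimensions
        double productNormSq = 0.0;
        for (Object p : product.values()) {
            double v = ((Number) p).doubleValue();
            productNormSq += v * v;
        }
        if (userNormSq == 0.0 || productNormSq == 0.0) {
            return 0.0; // cosine is undefined for a zero-length vector
        }
        return dot / (Math.sqrt(userNormSq) * Math.sqrt(productNormSq));
    }
}
```

Computing the product norm over every product dimension, not only the ones the user vector mentions, is what keeps the result a true cosine: a product with large weights on dimensions the user does not care about is scored lower.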
Script 2 (fast-vector-distance)
package com;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.store.ByteArrayDataInput;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.ScriptPlugin;
import org.elasticsearch.script.ScriptContext;
import org.elasticsearch.script.ScriptEngine;
import org.elasticsearch.script.SearchScript;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.util.Collection;
import java.util.List;
import java.util.Map;

/**
 * Created with IntelliJ IDEA.
 * User: 王火斌
 * Date: 18-8-9
 * Time: 2:32 PM
 * Description: to get personalized recommendations in search, we compute the similarity between
 * the user vector and each product's feature vector. The higher the similarity, the higher
 * the final score and the higher the product ranks.
 *
 * This class is instantiated when Elasticsearch loads the plugin for the
 * first time. If you change the name of this plugin, make sure to update
 * the classname in src/main/resources/plugin-descriptor.properties.
 */
public final class FastVectorDistance extends Plugin implements ScriptPlugin {

    @Override
    public ScriptEngine getScriptEngine(Settings settings, Collection<ScriptContext<?>> contexts) {
        return new FastVectorDistanceEngine();
    }

    private static class FastVectorDistanceEngine implements ScriptEngine {
        private static final Logger logger = LogManager.getLogger(FastVectorDistance.class);
        private static final int DOUBLE_SIZE = 8;

        @Override
        public String getType() {
            return "feature_vector_scoring_script";
        }

        @Override
        public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> context, Map<String, String> params) {
            logger.info("The feature_vector_scoring_script is calculating the similarity of users and commodities");
            if (!context.equals(SearchScript.CONTEXT)) {
                throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]");
            }
            if ("whb_fvd".equals(scriptSource)) {
                SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() {
                    // The field to compare against
                    final String field;
                    // Whether this search should use cosine similarity (default) or the plain dot product
                    final boolean cosine;
                    // Parsed but not used by the scoring logic below
                    final boolean exclude;
                    // The query vector
                    final double[] inputVector;
                    // Sum of squares of the query vector; stored per query (not on the shared engine)
                    // so that concurrent queries cannot overwrite each other's value
                    final double queryVectorNorm;

                    {
                        if (p.containsKey("field") == false) {
                            throw new IllegalArgumentException("Missing parameter [field]");
                        }
                        field = p.get("field").toString();
                        final Object cosineBool = p.get("cosine");
                        cosine = cosineBool != null ? (boolean) cosineBool : true;
                        final Object excludeBool = p.get("exclude");
                        exclude = excludeBool != null ? (boolean) excludeBool : true;

                        // Either a raw JSON array ("vector") or a Base64 string ("encoded_vector") may be passed
                        final Object vector = p.get("vector");
                        if (vector != null) {
                            final List<?> tmp = (List<?>) vector;
                            inputVector = new double[tmp.size()];
                            for (int i = 0; i < inputVector.length; i++) {
                                inputVector[i] = ((Number) tmp.get(i)).doubleValue();
                            }
                        } else {
                            final Object encodedVector = p.get("encoded_vector");
                            if (encodedVector == null) {
                                throw new IllegalArgumentException("Must have 'vector' or 'encoded_vector' as a parameter");
                            }
                            // Util is a helper class from the project repository (not shown here)
                            inputVector = Util.convertBase64ToArray((String) encodedVector);
                        }

                        // Compute the query vector norm once per query
                        double norm = 0d;
                        for (double v : inputVector) {
                            norm += v * v;
                        }
                        queryVectorNorm = norm;
                    }

                    @Override
                    public SearchScript newInstance(LeafReaderContext context) throws IOException {
                        // Use the Lucene LeafReaderContext to access binary doc values directly;
                        // null if this segment has no doc values for the field
                        final BinaryDocValues accessor = context.reader().getBinaryDocValues(field);
                        return new SearchScript(p, lookup, context) {
                            boolean hasValue = false;

                            @Override
                            public void setDocument(int docId) {
                                // advanceExact must be called with non-decreasing doc IDs and
                                // returns false when this document has no value for the field
                                if (accessor == null) {
                                    hasValue = false;
                                    return;
                                }
                                try {
                                    hasValue = accessor.advanceExact(docId);
                                } catch (IOException e) {
                                    hasValue = false;
                                }
                            }

                            @Override
                            public double runAsDouble() {
                                // If there is no field value, return 0 rather than fail
                                if (!hasValue) {
                                    return 0.0d;
                                }
                                final int inputVectorSize = inputVector.length;
                                final BytesRef ref;
                                try {
                                    ref = accessor.binaryValue();
                                } catch (IOException e) {
                                    return 0d;
                                }
                                // Honor the BytesRef offset/length instead of reading the whole backing array
                                final ByteArrayDataInput byteDocVector = new ByteArrayDataInput(ref.bytes, ref.offset, ref.length);
                                byteDocVector.readVInt(); // number of values
                                final int docVectorLength = byteDocVector.readVInt(); // number of bytes to read
                                if (docVectorLength != inputVectorSize * DOUBLE_SIZE) {
                                    return 0d;
                                }
                                final int position = byteDocVector.getPosition();
                                final DoubleBuffer doubleBuffer = ByteBuffer.wrap(ref.bytes, position, docVectorLength).asDoubleBuffer();
                                final double[] docVector = new double[inputVectorSize];
                                doubleBuffer.get(docVector);

                                double docVectorNorm = 0d;
                                double score = 0d;
                                // Dot product of the document vector and the query vector
                                for (int i = 0; i < inputVectorSize; i++) {
                                    score += docVector[i] * inputVector[i];
                                    if (cosine) {
                                        docVectorNorm += docVector[i] * docVector[i];
                                    }
                                }
                                // If cosine, normalize the dot product by both vector lengths
                                if (cosine) {
                                    if (docVectorNorm == 0 || queryVectorNorm == 0) {
                                        return 0d;
                                    }
                                    score = score / (Math.sqrt(docVectorNorm) * Math.sqrt(queryVectorNorm));
                                }
                                return score;
                            }
                        };
                    }

                    @Override
                    public boolean needs_score() {
                        return false;
                    }
                };
                return context.factoryClazz.cast(factory);
            }
            throw new IllegalArgumentException("Unknown script name " + scriptSource);
        }

        @Override
        public void close() {
        }
    }
}
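runAsDouble() in script 2 relies on a fixed byte layout in the binary doc value: a vInt with the number of values, a vInt with the byte length of the value, then the raw bytes, which it reads as big-endian doubles. The JDK-only sketch below reproduces that layout so it can be tested outside Elasticsearch. The class BinaryVectorLayout is hypothetical, and readVInt/writeVInt are hand-written equivalents of Lucene's variable-length int (7 data bits per byte, low-order group first, high bit set while more bytes follow), not Lucene's own classes.

```java
import java.nio.ByteBuffer;

// Hypothetical helper mirroring the parsing done in runAsDouble(), using only the JDK.
class BinaryVectorLayout {

    // Lucene-style vInt: 7 data bits per byte, low-order group first, high bit = "more bytes follow"
    static int readVInt(ByteBuffer in) {
        int value = 0;
        for (int shift = 0; ; shift += 7) {
            byte b = in.get();
            value |= (b & 0x7F) << shift;
            if ((b & 0x80) == 0) {
                return value;
            }
        }
    }

    static void writeVInt(ByteBuffer out, int value) {
        while ((value & ~0x7F) != 0) {
            out.put((byte) ((value & 0x7F) | 0x80));
            value >>>= 7;
        }
        out.put((byte) value);
    }

    // Bytes for a single vector value: vInt(count = 1), vInt(byte length), big-endian doubles
    static byte[] encode(double[] vector) {
        ByteBuffer out = ByteBuffer.allocate(10 + vector.length * 8); // each vInt takes at most 5 bytes
        writeVInt(out, 1);
        writeVInt(out, vector.length * 8);
        for (double v : vector) {
            out.putDouble(v);
        }
        byte[] result = new byte[out.position()];
        out.flip();
        out.get(result);
        return result;
    }

    // Returns null when the stored length does not match the expected dimension
    // (the point where the script returns 0)
    static double[] decode(byte[] bytes, int expectedDims) {
        ByteBuffer in = ByteBuffer.wrap(bytes);
        readVInt(in); // number of values; only the first value is read
        int length = readVInt(in);
        if (length != expectedDims * 8) {
            return null;
        }
        double[] vector = new double[expectedDims];
        in.asDoubleBuffer().get(vector);
        return vector;
    }
}
```

The length check is what lets the script skip documents whose stored vector has a different dimension from the query vector instead of throwing.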
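Deployment
We build and package the plugin with Maven. The steps are as follows:
1. Configure the pom file
Declare the dependencies and configure where the packaged plugin is written.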
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>es-plugin</groupId>
    <artifactId>elasticsearch-plugin</artifactId>
    <version>1.0-SNAPSHOT</version>

    <dependencies>
        <dependency>
            <groupId>org.elasticsearch</groupId>
            <artifactId>elasticsearch</artifactId>
            <version>6.1.1</version>
        </dependency>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>4.12</version>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <artifactId>maven-assembly-plugin</artifactId>
                <version>2.3</version>
                <configuration>
                    <appendAssemblyId>false</appendAssemblyId>
                    <outputDirectory>${project.build.directory}/releases/</outputDirectory>
                    <descriptors>
                        <descriptor>${basedir}/src/assembly/plugin.xml</descriptor>
                    </descriptors>
                </configuration>
                <executions>
                    <execution>
                        <phase>package</phase>
                        <goals>
                            <goal>single</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>
2. Create the assembly descriptor (src/assembly/plugin.xml, the path referenced in the pom)
<?xml version="1.0"?>
<assembly>
    <id>plugin</id>
    <formats>
        <format>zip</format>
    </formats>
    <includeBaseDirectory>false</includeBaseDirectory>
    <fileSets>
        <fileSet>
            <directory>${project.basedir}/src/main/resources</directory>
            <outputDirectory>feature-vector-score</outputDirectory>
        </fileSet>
    </fileSets>
    <dependencySets>
        <dependencySet>
            <outputDirectory>feature-vector-score</outputDirectory>
            <useProjectArtifact>true</useProjectArtifact>
            <useTransitiveFiltering>true</useTransitiveFiltering>
            <excludes>
                <exclude>org.elasticsearch:elasticsearch</exclude>
                <exclude>org.apache.logging.log4j:log4j-api</exclude>
            </excludes>
        </dependencySet>
    </dependencySets>
</assembly>
3. Create the plugin-descriptor.properties file (in src/main/resources, which the assembly copies into the zip)
description=feature-vector-similarity
version=1.0
name=feature-vector-score
site=${elasticsearch.plugin.site}
jvm=true
classname=com.FeatureVectorScoreSearchPlugin
java.version=1.8
elasticsearch.version=6.1.1
isolated=${elasticsearch.plugin.isolated}
description (String): simple summary of the plugin
version (String): the plugin's version
name (String): the plugin name
classname (String): the fully-qualified name of the class to load
java.version (String): the version of Java the code is built against. Use the system property java.specification.version. The version string must be a sequence of nonnegative decimal integers separated by "."s and may have leading zeros.
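The descriptor is a plain Java properties file, so each key=value pair must sit on its own line; a descriptor flattened onto a single line is read as one description key whose value is everything after the first "=". The JDK-only sketch below loads a descriptor with java.util.Properties, reports missing keys, and checks java.version against the format rule above. The class DescriptorCheck and its list of required keys are assumptions for illustration, not the validation Elasticsearch itself performs at install time. Its main method prints java.specification.version, the system property the documentation recommends.

```java
import java.io.IOException;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;

// Hypothetical pre-packaging check for plugin-descriptor.properties (not Elasticsearch's own validator).
class DescriptorCheck {
    // Assumed set of keys to require; adjust to the target Elasticsearch version
    static final String[] REQUIRED_KEYS = {
        "description", "version", "name", "classname", "java.version", "elasticsearch.version"
    };

    // Nonnegative decimal integers separated by '.', leading zeros allowed (e.g. "1.8", "10", "01.8")
    static boolean isValidJavaVersion(String value) {
        return value != null && value.matches("\\d+(\\.\\d+)*");
    }

    // Returns the problems found; an empty list means the descriptor looks usable
    static List<String> check(String descriptorText) {
        Properties props = new Properties();
        try {
            props.load(new StringReader(descriptorText));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        List<String> problems = new ArrayList<>();
        for (String key : REQUIRED_KEYS) {
            if (props.getProperty(key) == null) {
                problems.add("missing " + key);
            }
        }
        if (!isValidJavaVersion(props.getProperty("java.version"))) {
            problems.add("invalid java.version");
        }
        return problems;
    }

    public static void main(String[] args) {
        // The value to put in java.version for the JVM you build with, e.g. "1.8"
        System.out.println(System.getProperty("java.specification.version"));
    }
}
```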
Testing
Create the index
create_index = {
    "settings": {
        "analysis": {
            "analyzer": {
                # this configures the custom analyzer we need to parse vectors such that the scoring
                # plugin will work correctly
                "payload_analyzer": {
                    "type": "custom",
                    "tokenizer": "whitespace",
                    "filter": "delimited_payload_filter"
                }
            }
        }
    },
    "mappings": {
        "movies": {
            # this mapping definition sets up the metadata fields for the movies
            "properties": {
                "movieId": {"type": "integer"},
                "tmdbId": {"type": "keyword"},
                "genres": {"type": "keyword"},
                "release_date": {"type": "date", "format": "year"},
                "@model": {
                    # this mapping definition sets up the fields for movie factor vectors of our model
                    "properties": {
                        "factor": {"type": "binary", "doc_values": True},
                        "version": {"type": "keyword"},
                        "timestamp": {"type": "date"}
                    }
                }
            }
        }
    }
}
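In this mapping @model.factor is a binary field, so each indexed document supplies the vector as a Base64 string. Script 2 reads the stored bytes through ByteBuffer.asDoubleBuffer(), whose byte order is big-endian by default, so the doubles should be packed big-endian before encoding. The JDK-only sketch below shows one way to build that string and a minimal movies document; FactorEncoder, sampleDocument and the sample values (including the "v1" version) are hypothetical.

```java
import java.nio.ByteBuffer;
import java.util.Base64;

// Hypothetical helper: packs a factor vector into the Base64 form stored in the binary field.
class FactorEncoder {

    static String encode(double[] vector) {
        ByteBuffer buffer = ByteBuffer.allocate(vector.length * Double.BYTES); // big-endian by default
        for (double v : vector) {
            buffer.putDouble(v);
        }
        return Base64.getEncoder().encodeToString(buffer.array());
    }

    // Builds a minimal JSON document for the "movies" type; the field values are made up
    static String sampleDocument(int movieId, double[] factor) {
        return "{\"movieId\": " + movieId
            + ", \"@model\": {\"factor\": \"" + encode(factor) + "\", \"version\": \"v1\"}}";
    }

    public static void main(String[] args) {
        System.out.println(sampleDocument(1, new double[] {0.25, -0.5, 1.0}));
    }
}
```

Each double takes 8 bytes, so an n-dimensional vector becomes 8n bytes before Base64 encoding; that is the length script 2 compares against the query vector.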
Query
Execute the script by setting its lang to feature_vector_scoring_script (the value returned by getType()) and using the script name, whb_fvd, as the script source:
{
  "query": {
    "function_score": {
      "query": {
        "match_all": {}
      },
      "functions": [
        {
          "script_score": {
            "script": {
              "source": "whb_fvd",
              "lang": "feature_vector_scoring_script",
              "params": {
                "field": "@model.factor",
                "cosine": true,
                "encoded_vector": "v9EUmGAAAAC/6f9VAAAAAL/j+OOgAAAAv+m6+oAAAAA/lTSDIAAAAL/FdkTAAAAAv7rKHKAAAAA/0iyEYAAAAD/ZUY6gAAAAP7TzYoAAAAA/1K4IAAAAAD+yH9XgAAAAv6QRBSAAAAA/vRiiwAAAAL/mRhzgAAAAv9WxpiAAAAC/8YD+QAAAAL/jpbtgAAAAv+zmD+AAAAC/1eqtIAAAAA=="
              }
            }
          }
        }
      ]
    }
  }
}
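Script 2 turns encoded_vector into a double[] with Util.convertBase64ToArray, a helper from the project repository that is not shown here. Judging from how runAsDouble() reads document bytes, a compatible helper would Base64-decode the string and read it as big-endian doubles. The sketch below is that assumption written out with the JDK (EncodedVectorDecoder is a hypothetical name); it can also be used to inspect an encoded_vector before sending it.

```java
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.util.Base64;

// Hypothetical stand-in for the project's Util.convertBase64ToArray (its real code is not shown):
// Base64-decode the parameter and read the bytes as big-endian doubles.
class EncodedVectorDecoder {

    static double[] convertBase64ToArray(String encoded) {
        byte[] bytes = Base64.getDecoder().decode(encoded);
        if (bytes.length % Double.BYTES != 0) {
            throw new IllegalArgumentException("Byte length " + bytes.length + " is not a multiple of 8");
        }
        DoubleBuffer doubles = ByteBuffer.wrap(bytes).asDoubleBuffer();
        double[] result = new double[doubles.remaining()];
        doubles.get(result);
        return result;
    }

    public static void main(String[] args) {
        double[] v = convertBase64ToArray(args.length > 0 ? args[0] : "P/AAAAAAAAA=");
        System.out.println(v.length + " dimensions, first = " + v[0]);
    }
}
```

If the decoded length differs from the length of the stored document vectors, script 2 returns 0 for every document, so checking the dimension here is a quick way to debug an all-zero result.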
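Version notes
Elasticsearch has been releasing new versions quickly over the past year. The plugins above rely on the SearchScript class, which applies to v5.4 through v6.4. Versions below 5.4 mainly use the ExecutableScript class, while versions later than 6.4 introduce a new class, ScoreScript, for custom scoring scripts.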
See the full project on GitHub:
https://github.com/SnailWhb/elasticsearch_pulgine_fast-vector-distance
References
[1]https://static.javadoc.io/org.elasticsearch/elasticsearch/6.0.1/org/elasticsearch/script/ScriptEngine.html
[2]https://www.elastic.co/guide/en/elasticsearch/reference/current/modules-scripting-engine.html
[3]https://github.com/jiashiwen/elasticsearchpluginsample
[4]https://www.elastic.co/guide/en/elasticsearch/plugins/6.3/plugin-authors.html
Author: 视野