diff --git a/face-search-client/src/main/java/com/visual/face/search/handle/CollectHandler.java b/face-search-client/src/main/java/com/visual/face/search/handle/CollectHandler.java index 5248f9f7ae55207ceabbfd8c5e6ecedd3e46f310..d3f6ce130a14fb4beaa4cacdd2b5c27aaa602f82 100755 --- a/face-search-client/src/main/java/com/visual/face/search/handle/CollectHandler.java +++ b/face-search-client/src/main/java/com/visual/face/search/handle/CollectHandler.java @@ -48,7 +48,8 @@ public class CollectHandler extends BaseHandler{ .setFaceColumns(collect.getFaceColumns()) .setShardsNum(collect.getShardsNum()) .setStorageFaceInfo(collect.getStorageFaceInfo()) - .setStorageEngine(collect.getStorageEngine()); + .setStorageEngine(collect.getStorageEngine()) + .setApproximateKnn(collect.isApproximateKnn()); return HttpClient.post(Api.getUrl(this.serverHost, Api.collect_create), collectReq); } diff --git a/face-search-client/src/main/java/com/visual/face/search/handle/SearchHandler.java b/face-search-client/src/main/java/com/visual/face/search/handle/SearchHandler.java index 06feb3ce3d63cc2c5963f7f77fe2a237cb39a287..57af5bc64ff44954647a011c8594cf6ca5fc7d17 100755 --- a/face-search-client/src/main/java/com/visual/face/search/handle/SearchHandler.java +++ b/face-search-client/src/main/java/com/visual/face/search/handle/SearchHandler.java @@ -45,7 +45,8 @@ public class SearchHandler extends BaseHandler{ .setMaxFaceNum(search.getMaxFaceNum()) .setLimit(search.getLimit()) .setConfidenceThreshold(search.getConfidenceThreshold()) - .setFaceScoreThreshold(search.getFaceScoreThreshold()); + .setFaceScoreThreshold(search.getFaceScoreThreshold()) + .setApproximateKnn(search.isApproximateKnn()); return HttpClient.post(Api.getUrl(this.serverHost, Api.visual_search), searchReq, new TypeReference>>(){}); } } diff --git a/face-search-client/src/main/java/com/visual/face/search/model/Collect.java b/face-search-client/src/main/java/com/visual/face/search/model/Collect.java index d08eb54a34ab873ced7eb602da63e6772b63279f..5323eacafae84ee4d262a0b74d848d3e8313403b 100755 --- a/face-search-client/src/main/java/com/visual/face/search/model/Collect.java +++ b/face-search-client/src/main/java/com/visual/face/search/model/Collect.java @@ -23,6 +23,9 @@ public class Collect> implements Serializab /**保留图片及人脸信息的存储组件**/ private StorageEngine storageEngine; + /**是否启用近似knn搜索**/ + private boolean approximateKnn; + /** * 构建集合对象 * @return @@ -51,6 +54,15 @@ public class Collect> implements Serializab return (ExtendsVo) this; } + public boolean isApproximateKnn(){ + return this.approximateKnn; + } + + public ExtendsVo setApproximateKnn(boolean approximateKnn){ + this.approximateKnn = approximateKnn; + return (ExtendsVo) this; + } + public Integer getShardsNum() { return shardsNum; } diff --git a/face-search-client/src/main/java/com/visual/face/search/model/Search.java b/face-search-client/src/main/java/com/visual/face/search/model/Search.java index 9042a4515185abeab51b48b8698188d298e4814b..aca13a642286398236a2085f8ad6b982a7e1f6be 100755 --- a/face-search-client/src/main/java/com/visual/face/search/model/Search.java +++ b/face-search-client/src/main/java/com/visual/face/search/model/Search.java @@ -15,6 +15,8 @@ public class Search> implements Serializable /**对输入图像中多少个人脸进行检索比对:默认5**/ private Integer maxFaceNum=5; + /**是否启用近似knn搜索**/ + private boolean approximateKnn = false; /** * 构建检索对象 * @param imageBase64 待检索的图片 @@ -33,6 +35,15 @@ public class Search> implements Serializable return (ExtendsVo) this; } + public boolean isApproximateKnn() { + return approximateKnn; + } + + public ExtendsVo setApproximateKnn(boolean approximateKnn) { + this.approximateKnn = approximateKnn; + return (ExtendsVo)this; + } + public Float getFaceScoreThreshold() { return faceScoreThreshold; } diff --git a/face-search-engine/src/main/java/com/visual/face/search/engine/api/SearchEngine.java b/face-search-engine/src/main/java/com/visual/face/search/engine/api/SearchEngine.java index c2f234d69d2df2cb0334b2d3f4dd20cadccc7675..298f64a414658b07d7c25e6365e3d8e2814c7e08 100755 --- a/face-search-engine/src/main/java/com/visual/face/search/engine/api/SearchEngine.java +++ b/face-search-engine/src/main/java/com/visual/face/search/engine/api/SearchEngine.java @@ -14,7 +14,11 @@ public interface SearchEngine { public boolean dropCollection(String collectionName); - public boolean createCollection(String collectionName, MapParam param); + default public boolean createCollection(String collectionName, MapParam param){ + return createCollection(collectionName,param,false); + }; + + public boolean createCollection(String collectionName, MapParam param,boolean approximateKnn); public boolean insertVector(String collectionName, String sampleId, String faceId, float[] vectors); @@ -22,7 +26,11 @@ public interface SearchEngine { public boolean deleteVectorByKey(String collectionName, List faceIds); - public SearchResponse search(String collectionName, float[][] features, String algorithm, int topK); + default public SearchResponse search(String collectionName, float[][] features, String algorithm, int topK){ + return search(collectionName,features,algorithm,topK,false); + }; + + public SearchResponse search(String collectionName, float[][] features, String algorithm, int topK, boolean approximateKnn); public float searchMinScoreBySampleId(String collectionName, String sampleId,float[] feature, String algorithm); diff --git a/face-search-engine/src/main/java/com/visual/face/search/engine/conf/Constant.java b/face-search-engine/src/main/java/com/visual/face/search/engine/conf/Constant.java index 9144894ba1e6dc42cb2b016cdffc1e55c579ffb7..7196ebd4dc2a40c7a5b7b292853b0fe85d4629f8 100755 --- a/face-search-engine/src/main/java/com/visual/face/search/engine/conf/Constant.java +++ b/face-search-engine/src/main/java/com/visual/face/search/engine/conf/Constant.java @@ -4,6 +4,7 @@ public class Constant { public final static String IndexShardsNum = "shardsNum"; public final static String IndexReplicasNum = "replicasNum"; + public final static String IndexAlgoParamEfSearch = "algoParamEfSearch"; public final static String ColumnNameFaceId = "face_id"; public final static String ColumnNameSampleId = "sample_id"; diff --git a/face-search-engine/src/main/java/com/visual/face/search/engine/impl/OpenSearchEngine.java b/face-search-engine/src/main/java/com/visual/face/search/engine/impl/OpenSearchEngine.java index aefe96a581a8eee5159d1ab9630b7a5467cae6e9..ab026a19fb3b12a8aa24b8e4e1d4c72c103cd36a 100644 --- a/face-search-engine/src/main/java/com/visual/face/search/engine/impl/OpenSearchEngine.java +++ b/face-search-engine/src/main/java/com/visual/face/search/engine/impl/OpenSearchEngine.java @@ -3,6 +3,7 @@ package com.visual.face.search.engine.impl; import com.visual.face.search.engine.api.SearchEngine; import com.visual.face.search.engine.conf.Constant; import com.visual.face.search.engine.exps.SearchEngineException; +import com.visual.face.search.engine.impl.query.ApproximateKnnQueryBuilder; import com.visual.face.search.engine.model.*; import org.apache.commons.collections4.MapUtils; import org.opensearch.action.DocWriteResponse; @@ -71,17 +72,32 @@ public class OpenSearchEngine implements SearchEngine { } @Override - public boolean createCollection(String collectionName, MapParam param) { + public boolean createCollection(String collectionName, MapParam param,boolean approximateKnn) { try { //构建请求 CreateIndexRequest createIndexRequest = new CreateIndexRequest(collectionName); - createIndexRequest.settings(Settings.builder() - .put("index.number_of_shards", param.getIndexShardsNum()) - .put("index.number_of_replicas", param.getIndexReplicasNum()) - ); + Settings.Builder builder = Settings.builder() + .put("index.number_of_shards", param.getIndexShardsNum()) + .put("index.number_of_replicas", param.getIndexReplicasNum()); + if(approximateKnn){ + //启用open search近似knn搜索支持 + builder.put("index.knn",true); + builder.put("index.knn.algo_param.ef_search",param.getIndexAlgoParamEfSearch()); + } + + createIndexRequest.settings(builder); HashMap properties = new HashMap<>(); properties.put(Constant.ColumnNameSampleId, Map.of("type", "keyword")); - properties.put(Constant.ColumnNameFaceVector, Map.of("type", "knn_vector", "dimension", "512")); + if(approximateKnn){ + //启用open search近似knn搜索支持 + properties.put(Constant.ColumnNameFaceVector, Map.of("type", "knn_vector", "dimension", "512", + "method",Map.of("engine","nmslib", + "space_type","cosinesimil", + "name","hnsw", + "parameters",Map.of()))); + }else { + properties.put(Constant.ColumnNameFaceVector, Map.of("type", "knn_vector", "dimension", "512")); + } createIndexRequest.mapping(Map.of("properties", properties)); //创建集合 CreateIndexResponse createIndexResponse = client.indices().create(createIndexRequest, RequestOptions.DEFAULT); @@ -135,23 +151,37 @@ public class OpenSearchEngine implements SearchEngine { } @Override - public SearchResponse search(String collectionName, float[][] features, String algorithm, int topK) { + public SearchResponse search(String collectionName, float[][] features, String algorithm, int topK,boolean approximateKnn) { try { //构建搜索请求 MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); for(float[] feature : features){ - QueryBuilder queryBuilder = new MatchAllQueryBuilder(); - Map params = new HashMap<>(); - params.put("field", Constant.ColumnNameFaceVector); - params.put("space_type", algorithm); - params.put("query_value", feature); - Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "knn", "knn_score", params); - ScriptScoreQueryBuilder scriptScoreQueryBuilder = new ScriptScoreQueryBuilder(queryBuilder, script); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() - .query(scriptScoreQueryBuilder).size(topK) - .fetchSource(null, Constant.ColumnNameFaceVector); //是否需要向量字段 - SearchRequest searchRequest = new SearchRequest(collectionName).source(searchSourceBuilder); - multiSearchRequest.add(searchRequest); + if(approximateKnn){ + //近似knn搜索 + Map params = new HashMap<>(); + params.put("vector",feature); + params.put("k",topK); + ApproximateKnnQueryBuilder approximateKnnQueryBuilder = new ApproximateKnnQueryBuilder(params); + SearchSourceBuilder searchSourceBuilder =new SearchSourceBuilder() + .query(approximateKnnQueryBuilder).size(topK) + .fetchSource(null,Constant.ColumnNameFaceVector); + SearchRequest searchRequest = new SearchRequest(collectionName).source(searchSourceBuilder); + multiSearchRequest.add(searchRequest); + }else { + //常规搜索 + QueryBuilder queryBuilder = new MatchAllQueryBuilder(); + Map params = new HashMap<>(); + params.put("field", Constant.ColumnNameFaceVector); + params.put("space_type", algorithm); + params.put("query_value", feature); + Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "knn", "knn_score", params); + ScriptScoreQueryBuilder scriptScoreQueryBuilder = new ScriptScoreQueryBuilder(queryBuilder, script); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .query(scriptScoreQueryBuilder).size(topK) + .fetchSource(null, Constant.ColumnNameFaceVector); //是否需要向量字段 + SearchRequest searchRequest = new SearchRequest(collectionName).source(searchSourceBuilder); + multiSearchRequest.add(searchRequest); + } } //查询索引 MultiSearchResponse response = this.client.msearch(multiSearchRequest, RequestOptions.DEFAULT); @@ -167,7 +197,7 @@ public class OpenSearchEngine implements SearchEngine { if(searchHits != null){ for(SearchHit searchHit : searchHits){ String faceId = searchHit.getId(); - float score = searchHit.getScore()-1; + float score = approximateKnn? searchHit.getScore() : (searchHit.getScore()-1); Map sourceMap = searchHit.getSourceAsMap(); String sampleId = MapUtils.getString(sourceMap, Constant.ColumnNameSampleId); Object faceVector = MapUtils.getObject(sourceMap, Constant.ColumnNameFaceVector); diff --git a/face-search-engine/src/main/java/com/visual/face/search/engine/impl/query/ApproximateKnnQueryBuilder.java b/face-search-engine/src/main/java/com/visual/face/search/engine/impl/query/ApproximateKnnQueryBuilder.java new file mode 100644 index 0000000000000000000000000000000000000000..42b4ada94776a268095e0e4852c37dacbbec029e --- /dev/null +++ b/face-search-engine/src/main/java/com/visual/face/search/engine/impl/query/ApproximateKnnQueryBuilder.java @@ -0,0 +1,58 @@ +package com.visual.face.search.engine.impl.query; + +import com.visual.face.search.engine.conf.Constant; +import org.apache.lucene.search.Query; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.AbstractQueryBuilder; +import org.opensearch.index.query.QueryShardContext; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +/** + * 近似KNN搜索 + * @Author Foy Lian + * @Date 2023/9/13 14:26 + **/ + +public class ApproximateKnnQueryBuilder extends AbstractQueryBuilder { + + private Map mParams; + + public ApproximateKnnQueryBuilder(Map params){ + this.mParams = params; + } + + @Override + protected void doWriteTo(StreamOutput streamOutput) throws IOException { + } + + @Override + protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException { + xContentBuilder.startObject("knn"); + xContentBuilder.field(Constant.ColumnNameFaceVector,mParams); + xContentBuilder.endObject(); + } + + @Override + protected Query doToQuery(QueryShardContext queryShardContext) throws IOException { + return null; + } + + @Override + protected boolean doEquals(ApproximateKnnQueryBuilder approximateKnnQueryBuilder) { + return Objects.equals(this.mParams, approximateKnnQueryBuilder.mParams); + } + + @Override + protected int doHashCode() { + return Objects.hash(new Object[]{this.mParams}); + } + + @Override + public String getWriteableName() { + return "knn"; + } +} diff --git a/face-search-engine/src/main/java/com/visual/face/search/engine/model/MapParam.java b/face-search-engine/src/main/java/com/visual/face/search/engine/model/MapParam.java index 595adfa0b76ac791d884aaa36c22d196158249b2..0074fb660777e484e56727254a665bcd26580fa3 100755 --- a/face-search-engine/src/main/java/com/visual/face/search/engine/model/MapParam.java +++ b/face-search-engine/src/main/java/com/visual/face/search/engine/model/MapParam.java @@ -70,5 +70,10 @@ public class MapParam extends ConcurrentHashMap { shardsNum = (null == shardsNum || shardsNum <= 0) ? 4 : shardsNum; return shardsNum; } - + + public Integer getIndexAlgoParamEfSearch(){ + Integer algoParamEfSearch = this.getInteger(Constant.IndexAlgoParamEfSearch, 512); + algoParamEfSearch = (null == algoParamEfSearch || algoParamEfSearch <= 0) ? 512 : algoParamEfSearch; + return algoParamEfSearch; + } } diff --git a/face-search-server/src/main/java/com/visual/face/search/server/domain/base/CollectVo.java b/face-search-server/src/main/java/com/visual/face/search/server/domain/base/CollectVo.java index 874640d81855a88b0b75fc21f7f19473d604ee54..c8a4e75a99dd570e4ce54a65a166b08105e23035 100755 --- a/face-search-server/src/main/java/com/visual/face/search/server/domain/base/CollectVo.java +++ b/face-search-server/src/main/java/com/visual/face/search/server/domain/base/CollectVo.java @@ -46,6 +46,11 @@ public class CollectVo> extends BaseVo { @ApiModelProperty(value="保留图片及人脸信息的存储组件", position = 9,required = false) private StorageEngine storageEngine; + /**是否启用近似knn搜索**/ + @ApiModelProperty(value="是否启用近似knn搜索", position = 10,required = false) + private boolean approximateKnn; + + /** * 构建集合对象 * @param namespace 命名空间 @@ -56,6 +61,15 @@ public class CollectVo> extends BaseVo { return new CollectVo().setNamespace(namespace).setCollectionName(collectionName); } + public boolean isApproximateKnn() { + return approximateKnn; + } + + public ExtendsVo setApproximateKnn(boolean approximateKnn){ + this.approximateKnn = approximateKnn; + return (ExtendsVo) this; + } + public String getNamespace() { return namespace; } @@ -144,4 +158,5 @@ public class CollectVo> extends BaseVo { } return (ExtendsVo) this; } + } diff --git a/face-search-server/src/main/java/com/visual/face/search/server/domain/request/FaceSearchReqVo.java b/face-search-server/src/main/java/com/visual/face/search/server/domain/request/FaceSearchReqVo.java index 7d0a917ba15fa0faceaa2b27f7b781b4e67d359d..2c5f241a8fc155c542b6cd7f701ad456f50dcd68 100755 --- a/face-search-server/src/main/java/com/visual/face/search/server/domain/request/FaceSearchReqVo.java +++ b/face-search-server/src/main/java/com/visual/face/search/server/domain/request/FaceSearchReqVo.java @@ -42,6 +42,12 @@ public class FaceSearchReqVo extends BaseVo { @ApiModelProperty(value="对输入图像中多少个人脸进行检索比对:默认5", position = 7, required = false) private Integer maxFaceNum; + /**是否使用近似knn搜索**/ + @ApiModelProperty(value="是否使用近似knn搜索", position = 8, required = false) + private boolean approximateKnn; + + + /** * 构建检索对象 * @param namespace 命名空间 @@ -130,4 +136,13 @@ public class FaceSearchReqVo extends BaseVo { this.maxFaceNum = maxFaceNum; return this; } + + public boolean isApproximateKnn() { + return approximateKnn; + } + + public FaceSearchReqVo setApproximateKnn(boolean approximateKnn) { + this.approximateKnn = approximateKnn; + return this; + } } diff --git a/face-search-server/src/main/java/com/visual/face/search/server/service/impl/CollectServiceImpl.java b/face-search-server/src/main/java/com/visual/face/search/server/service/impl/CollectServiceImpl.java index 7c5f6fdbd48b935b624dfb1e7fc736cf6cfc2023..7d5a1b5094de4c267804a3de5e965a40f915f760 100755 --- a/face-search-server/src/main/java/com/visual/face/search/server/service/impl/CollectServiceImpl.java +++ b/face-search-server/src/main/java/com/visual/face/search/server/service/impl/CollectServiceImpl.java @@ -90,7 +90,7 @@ public class CollectServiceImpl extends BaseService implements CollectService { MapParam param = MapParam.build() .put(Constant.IndexShardsNum, collect.getShardsNum()) .put(Constant.IndexReplicasNum, collect.getReplicasNum()); - boolean createVectorFlag = searchEngine.createCollection(vectorTableName, param); + boolean createVectorFlag = searchEngine.createCollection(vectorTableName, param,collect.isApproximateKnn()); if(!createVectorFlag){ throw new RuntimeException("create vector table error"); } diff --git a/face-search-server/src/main/java/com/visual/face/search/server/service/impl/FaceSearchServiceImpl.java b/face-search-server/src/main/java/com/visual/face/search/server/service/impl/FaceSearchServiceImpl.java index 968ffa45e5d7fad594b4d9ec9f1ba765272b8e65..628a1dc74d1f94200a83b297fc507d6fa6089338 100755 --- a/face-search-server/src/main/java/com/visual/face/search/server/service/impl/FaceSearchServiceImpl.java +++ b/face-search-server/src/main/java/com/visual/face/search/server/service/impl/FaceSearchServiceImpl.java @@ -83,7 +83,7 @@ public class FaceSearchServiceImpl extends BaseService implements FaceSearchServ } //特征搜索 int topK = (null == search.getLimit() || search.getLimit() <= 0) ? 5 : search.getLimit(); - SearchResponse searchResponse =searchEngine.search(collection.getVectorTable(), vectors, search.getAlgorithm().algorithm(), topK); + SearchResponse searchResponse =searchEngine.search(collection.getVectorTable(), vectors, search.getAlgorithm().algorithm(), topK,search.isApproximateKnn()); if(!searchResponse.getStatus().ok()){ throw new RuntimeException(searchResponse.getStatus().getReason()); } diff --git a/face-search-test/src/main/java/com/visual/face/search/valid/exps/FaceSearchExample.java b/face-search-test/src/main/java/com/visual/face/search/valid/exps/FaceSearchExample.java index 1c6b5ec38cb73288356e7e6208c363709af0c629..c64249111a8c52251150ec35b5d87e4d18f07d16 100644 --- a/face-search-test/src/main/java/com/visual/face/search/valid/exps/FaceSearchExample.java +++ b/face-search-test/src/main/java/com/visual/face/search/valid/exps/FaceSearchExample.java @@ -26,6 +26,12 @@ public class FaceSearchExample { public static String collectionName = "collect_20211201_v11"; public static FaceSearch faceSearch = FaceSearch.build(serverHost, namespace, collectionName); + //是否启用近似knn,建议底库集比较大时启用. + public static boolean approximateKnn = false; + + //底库集比较大时,建议调大,如:32个分片 + public static int shardsNum = 4; + /**集合创建*/ public static void collect(){ //样本属性字段 @@ -44,7 +50,12 @@ public class FaceSearchExample { //是否保存人脸及图片数据信息 .setStorageFaceInfo(true) //目前只实现了数据库存储,对其他类型存储实现StorageImageService接口即可 - .setStorageEngine(StorageEngine.CURR_DB); + .setStorageEngine(StorageEngine.CURR_DB) + //设置分片大小 + .setShardsNum(shardsNum) + //开启关闭近似knn搜索 + .setApproximateKnn(approximateKnn); + //删除集合 Response deleteCollect = faceSearch.collect().deleteCollect(); System.out.println(deleteCollect); @@ -97,6 +108,9 @@ public class FaceSearchExample { .search(Search.build(imageBase64) .setConfidenceThreshold(50f) //最小置信分:50 .setMaxFaceNum(10).setLimit(1) + //collect()创建集合时即使开了近似knn搜索,这里设置为false也可以使用精确knn搜索。 + //这里数据量足够大时,就能发现近似knn返回较快 + .setApproximateKnn(approximateKnn) ); Long e = System.currentTimeMillis(); System.out.println("search cost:" + (e-s)+"ms");