Elasticsearch 递归拼接聚合条件以及获取聚合值

3,214 阅读5分钟

基于 RestHighLevelClient 6.8.1,Elasticsearch 6.8.1

由于 Elasticsearch 的 DSL 拼接比较繁琐,和平常写 SQL 的习惯不太相同,因此把自己的实现思路整理并开源出来,和大家一起探讨

主要 Maven 依赖

<dependency>
  <groupId>org.elasticsearch.client</groupId>
  <artifactId>elasticsearch-rest-high-level-client</artifactId>
  <version>6.8.1</version>
  <exclusions>
    <exclusion>
      <groupId>org.elasticsearch</groupId>
      <artifactId>elasticsearch</artifactId>
    </exclusion>
  </exclusions>
</dependency>

<dependency>
  <groupId>org.elasticsearch</groupId>
  <artifactId>elasticsearch</artifactId>
  <version>6.8.1</version>
</dependency>

需要导入的类(import 语句)

import com.alibaba.fastjson.JSONObject;
import org.apache.commons.lang3.StringUtils;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.RangeQueryBuilder;
import org.elasticsearch.script.Script;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.cardinality.Cardinality;
import org.elasticsearch.search.aggregations.metrics.cardinality.CardinalityAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.sum.Sum;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.springframework.util.CollectionUtils;

import javax.annotation.Resource;
import java.io.IOException;
import java.math.BigDecimal;
import java.util.*;

测试代码

保存 Elasticsearch 连接的 bean

/**
 * Spring bean that owns a single {@link RestHighLevelClient} instance.
 *
 * <p>The client is built in {@link #afterPropertiesSet()} from the {@code es.hosts},
 * {@code es.port}, {@code es.username} and {@code es.password} properties and is closed in
 * {@link #destroy()}.
 */
@Component
public class ElasticsearchClientHolder implements InitializingBean, DisposableBean {
  private static final Logger LOGGER = LoggerFactory.getLogger(ElasticsearchClientHolder.class);

  private RestHighLevelClient client;

  /**
   * @return the client created in {@link #afterPropertiesSet()}; {@code null} until that method has run
   */
  public RestHighLevelClient getClient() {
    return client;
  }

  @Resource
  private PropertiesHolder propertiesHolder;

  /**
   * Builds the REST client with basic-auth credentials taken from the configured properties.
   */
  @Override
  public void afterPropertiesSet() {
    final CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
    credentialsProvider.setCredentials(
      AuthScope.ANY,
      new UsernamePasswordCredentials(
        propertiesHolder.getProperty("es.username"),
        propertiesHolder.getProperty("es.password")));
    RestClientBuilder builder = RestClient.builder(getHost())
      .setHttpClientConfigCallback(httpClientBuilder -> httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider));
    client = new RestHighLevelClient(builder);
    // NOTE(review): no request is issued in this method, so this message only confirms that the
    // client object was constructed.
    LOGGER.info("连接 Elasticsearch 成功");
  }

  /**
   * Turns the comma-separated {@code es.hosts} property into {@link HttpHost}s that all share the
   * {@code es.port} port and the {@code http} scheme.
   *
   * @return one host per comma-separated entry
   */
  private HttpHost[] getHost() {
    String property = propertiesHolder.getProperty("es.hosts");
    // The port is the same for every host, so parse it once instead of once per iteration.
    int port = Integer.parseInt(propertiesHolder.getProperty("es.port"));
    String[] split = property.split(",");
    HttpHost[] httpHosts = new HttpHost[split.length];
    for (int i = 0; i < split.length; i++) {
      // Trim so that "host1, host2" does not produce a host name with a leading blank.
      httpHosts[i] = new HttpHost(split[i].trim(), port, "http");
    }
    return httpHosts;
  }

  /**
   * Closes the client if one was created. A failure to close is logged rather than propagated so
   * that shutdown can continue.
   */
  @Override
  public void destroy() {
    if (client != null) {
      try {
        client.close();
      } catch (IOException e) {
        LOGGER.warn("Failed to close Elasticsearch client", e);
      }
    }
  }
}
/**
 * Demo of an aggregation query: builds a filtered search request with sum / cardinality
 * aggregations (optionally nested under terms aggregations) and collects the response.
 */
public class ElasticsearchDemo {
  @Resource
  private ElasticsearchClientHolder elasticsearchClientHolder;

  /**
   * Must be initialized before use: the key is a field name and the value is the Elasticsearch
   * script source used when aggregating that field.
   **/
  private Map<String, String> map = new HashMap<>();

  /**
   * Runs an aggregation query.
   *
   * <p>NOTE(review): {@code key}, {@code value} and {@code list} used in the filter section are
   * not declared in this method; they appear to be placeholders for real filter values. The
   * collected result is assigned to a local variable and is not returned.
   *
   * @param groupBySet   fields to group by
   * @param sumFieldList fields to sum
   * @param cardinality  field for the count(distinct ..) operation
   * @param docCount     name under which the count(1) result is stored
   * @author fageiguanbing
   * @date 2019/10/18
   **/
  private void query(Set<String> groupBySet, List<String> sumFieldList, String cardinality, String docCount) throws IOException {
    // Query conditions
    BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery();

    // Equivalent of an "equals" condition
    queryBuilder.filter(QueryBuilders.termQuery(key, value));
    // Equivalent of an "in" condition
    queryBuilder.filter(QueryBuilders.termsQuery(key, list));


    // Date range condition
    RangeQueryBuilder rangeQueryBuilder = QueryBuilders.
      rangeQuery("要查询的属性")
      .from("开始时间")
      .to("结束时间")
      .format("yyyy-MM-dd HH:mm:ss");
    queryBuilder.filter(rangeQueryBuilder);


    // Builder holding both the query and the aggregations
    SearchSourceBuilder sourceBuilder = new SearchSourceBuilder()
      // Paging parameters of the search hits, not of the aggregations, hence both are 0
      .from(0)
      .size(0)
      .query(queryBuilder);

    List<String> groupByList = new ArrayList<>(groupBySet);
    // Query without group by
    if (CollectionUtils.isEmpty(groupByList)) {
      for (String sum : sumFieldList) {
        sourceBuilder.aggregation(
          AggregationBuilders
          // Aggregation name; the same name is used to read the value from the response
          .sum(sum)
          // The script could come from elsewhere (a data source, a configuration file, ...)
          .script(new Script(map.get(sum)))
        );
      }
      // count(distinct ...)
      if (StringUtils.isNotBlank(cardinality)) {
        CardinalityAggregationBuilder cardinalityAggregationBuilder = AggregationBuilders
          .cardinality(cardinality)
          .script(new Script(map.get(cardinality)));
        sourceBuilder.aggregation(cardinalityAggregationBuilder);
      }
    }
    // Query with group by
    else {
      // Group-by aggregations are nested level by level and need index-based access,
      // which is why the de-duplicated group-by set was converted to a list above
      AggregationBuilder termsAggregationBuilder = buildGroupBy(groupByList, sumFieldList, cardinality);
      sourceBuilder.aggregation(termsAggregationBuilder);
    }

    SearchRequest searchRequest = new SearchRequest();
    // Index (and type) to query; the literals below are placeholders
    searchRequest.indices("此处为索引");
    searchRequest.types("此处为type");
    searchRequest.source(sourceBuilder);
    SearchResponse searchResponse = elasticsearchClientHolder.getClient().search(searchRequest, RequestOptions.DEFAULT);

    // Final, usable statistics
    List<Map<String, String>> list;
    // Query with group by
    if (!CollectionUtils.isEmpty(groupByList)) {
      list = collectResponse(searchResponse, groupByList, sumFieldList, docCount, cardinality);
    }
    // Query without group by
    else {
      list = noGroupByCollectResponse(searchResponse, sumFieldList, docCount, cardinality);
    }
  }
}
/**
 * Builds the outermost terms aggregation for a (possibly multi-field) group by.
 *
 * <p>The first entry of {@code groupBy} becomes the root terms aggregation. Remaining entries
 * are nested below it via {@code buildTermsAggregation}; when there are none, the sum and
 * cardinality aggregations are attached directly via {@code buildStatisticalFieldBySource}.
 * Neither input list is modified.
 *
 * @param groupBy      group-by fields, outermost first; must contain at least one entry
 * @param sumFieldList fields to sum
 * @param cardinality  field for count(distinct ...); skipped by the helper when blank
 * @return the root aggregation with all nested aggregations attached
 * @author fageiguanbing
 * @date 2019/7/16
 **/
private AggregationBuilder buildGroupBy(List<String> groupBy, List<String> sumFieldList, String cardinality) {
  String rootField = groupBy.get(0);
  // Elasticsearch 5.0+ rejects size(0) on a terms aggregation with an IllegalArgumentException,
  // so "return every bucket" is expressed with the largest allowed size instead.
  // NOTE(review): an unbounded size can be expensive on high-cardinality fields.
  AggregationBuilder termsAggregationBuilder =
    AggregationBuilders.terms(rootField).field(rootField).size(Integer.MAX_VALUE);

  // Hand mutable copies to the helpers so the caller's lists are never affected.
  List<String> remainingGroupBy = new ArrayList<>(groupBy.subList(1, groupBy.size()));
  List<String> sumFields = new ArrayList<>(sumFieldList);

  if (!remainingGroupBy.isEmpty()) {
    // More group-by fields: nest them recursively.
    buildTermsAggregation(termsAggregationBuilder, remainingGroupBy, sumFields, cardinality);
  } else {
    // Single group-by field: attach the statistics right below it.
    buildStatisticalFieldBySource(termsAggregationBuilder, sumFields, cardinality);
  }
  return termsAggregationBuilder;
}
/**
 * Attaches the statistic sub-aggregations (one scripted sum per field, plus an optional
 * scripted cardinality) to the given aggregation.
 *
 * <p>The script for every field is looked up in {@code map} by field name.
 *
 * @param aggregationBuilder aggregation that receives the sub-aggregations
 * @param sumField           either a single field name ({@code String}) or a
 *                           {@code Collection} of field names; {@code null} means no sums
 * @param cardinality        field for count(distinct ...); skipped when blank
 * @throws IllegalArgumentException if {@code sumField} is neither a {@code String} nor a
 *                                  {@code Collection} of {@code String}s
 * @author fageiguanbing
 * @date 2019/7/17
 **/
private void buildStatisticalFieldBySource(AggregationBuilder aggregationBuilder, Object sumField, String cardinality) {
  List<String> sumFields = new ArrayList<>();
  if (sumField instanceof String) {
    sumFields.add((String) sumField);
  } else if (sumField instanceof Collection) {
    // Accept any collection type, not only ArrayList, and check elements up front.
    for (Object element : (Collection<?>) sumField) {
      if (!(element instanceof String)) {
        throw new IllegalArgumentException("sumField elements must be Strings, got: " + element);
      }
      sumFields.add((String) element);
    }
  } else if (sumField != null) {
    throw new IllegalArgumentException(
      "sumField must be a String or a Collection of Strings, got: " + sumField.getClass().getName());
  }

  AggregatorFactories.Builder builder = new AggregatorFactories.Builder();
  // Scripted sum per field; the aggregation name equals the field name.
  for (String sum : sumFields) {
    builder.addAggregator(AggregationBuilders.sum(sum)
                          .script(new Script(map.get(sum))));
  }
  // count(distinct ...)
  if (StringUtils.isNotBlank(cardinality)) {
    CardinalityAggregationBuilder cardinalityAggregationBuilder = AggregationBuilders
      .cardinality(cardinality)
      .script(new Script(map.get(cardinality)));
    builder.addAggregator(cardinalityAggregationBuilder);
  }
  aggregationBuilder.subAggregations(builder);
}
/**
 * Recursively nests one terms aggregation per remaining group-by field below
 * {@code aggregationBuilder}; the innermost one receives the statistic aggregations.
 *
 * <p>The {@code groupBy} list is only read, never modified.
 *
 * @param aggregationBuilder parent aggregation that receives the new terms aggregation
 * @param groupBy            remaining group-by fields, outermost first; must not be empty
 * @param sumField           sum field(s), passed through to {@code buildStatisticalFieldBySource}
 * @param cardinality        field for count(distinct ...)
 * @author fageiguanbing
 * @date 2019/7/17
 **/
private void buildTermsAggregation(AggregationBuilder aggregationBuilder, List<String> groupBy, Object sumField, String cardinality) {
  String field = groupBy.get(0);
  // Elasticsearch 5.0+ rejects size(0) on a terms aggregation with an IllegalArgumentException,
  // so "return every bucket" is expressed with the largest allowed size instead.
  // NOTE(review): an unbounded size can be expensive on high-cardinality fields.
  TermsAggregationBuilder builder =
    AggregationBuilders.terms(field).field(field).size(Integer.MAX_VALUE);

  // A view of the remaining fields; avoids removing elements from the caller's list.
  List<String> remaining = groupBy.subList(1, groupBy.size());
  if (!remaining.isEmpty()) {
    // More group-by fields: keep nesting.
    buildTermsAggregation(builder, remaining, sumField, cardinality);
  } else {
    // Innermost group-by field: attach the statistics.
    buildStatisticalFieldBySource(builder, sumField, cardinality);
  }
  aggregationBuilder.subAggregation(builder);
}
/**
 * Translates a request object into a bool query made of filter clauses.
 *
 * <p>A non-empty {@code String} value becomes a term filter on its key, a non-empty
 * {@code List} value becomes a terms filter; every other entry is skipped.
 *
 * @param request request whose entries describe the conditions
 * @return the assembled bool query
 * @author fageiguanbing
 * @date 2019/7/16
 **/
private BoolQueryBuilder buildQueryCondition(JSONObject request) {
  BoolQueryBuilder conditions = QueryBuilders.boolQuery();
  for (Map.Entry<String, Object> entry : request.entrySet()) {
    String field = entry.getKey();
    Object candidate = entry.getValue();

    if (candidate instanceof String) {
      String text = (String) candidate;
      if (!text.isEmpty()) {
        // Single value: exact match.
        conditions.filter(QueryBuilders.termQuery(field, text));
      }
      continue;
    }

    if (candidate instanceof List) {
      List<?> options = (List<?>) candidate;
      if (!options.isEmpty()) {
        // Several values: match any of them.
        conditions.filter(QueryBuilders.termsQuery(field, options));
      }
    }
  }
  return conditions;
}
/**
 * Flattens the nested terms aggregations of a grouped search response into a list of rows.
 *
 * @param searchResponse   response returned by Elasticsearch
 * @param groupByFieldList group-by fields, outermost first
 * @param sumFieldList     summed fields
 * @param docCount         key under which the count(1) value is stored
 * @param cardinality      field used for count(distinct ...)
 * @return one map per innermost bucket; empty when the outermost aggregation has no buckets
 * @author fageiguanbing
 * @date 2019/7/18
 **/
private List<Map<String, String>> collectResponse(SearchResponse searchResponse, List<String> groupByFieldList, List<String> sumFieldList, String docCount, String cardinality) {
  List<Map<String, String>> rows = new ArrayList<>();

  // The outermost aggregation is named after the first group-by field.
  Terms rootTerms = searchResponse.getAggregations().get(groupByFieldList.get(0));
  List<? extends Terms.Bucket> rootBuckets = rootTerms.getBuckets();
  if (CollectionUtils.isEmpty(rootBuckets)) {
    return rows;
  }

  // Starts empty; the recursion records the bucket keys of the enclosing levels in it.
  Map<String, String> ancestorKeys = new HashMap<>();
  recursionCollectResponse(rootBuckets, groupByFieldList, 0, sumFieldList, rows, ancestorKeys, docCount, cardinality);
  return rows;
}
/**
 * Walks the nested terms buckets recursively and appends one row per innermost bucket to
 * {@code returnList}.
 *
 * <p>Each row holds the bucket key of every group-by level plus, when requested, the document
 * count, the cardinality value and every sum. Sum values are rendered in plain (non-scientific)
 * notation from the shortest decimal representation of the double, without trailing zeros.
 *
 * @param buckets          buckets of the level identified by {@code position}
 * @param groupByFieldList group-by fields, outermost first
 * @param position         index into {@code groupByFieldList} of the current level
 * @param sumFieldList     summed fields
 * @param returnList       result rows; appended to by this method
 * @param paramMap         bucket keys of the enclosing levels; updated while descending
 * @param docCount         key under which the count(1) value is stored; skipped when blank
 * @param cardinality      field used for count(distinct ...); skipped when blank
 * @author fageiguanbing
 * @date 2019/7/18
 **/
private void recursionCollectResponse(List<? extends Terms.Bucket> buckets,
                                      List<String> groupByFieldList,
                                      int position,
                                      List<String> sumFieldList,
                                      List<Map<String, String>> returnList,
                                      Map<String, String> paramMap, String docCount,
                                      String cardinality) {
  // Innermost group-by level: its buckets carry the statistic aggregations.
  if (groupByFieldList.size() == position + 1) {
    for (Terms.Bucket bucket : buckets) {
      Map<String, String> row = new HashMap<>();
      row.put(groupByFieldList.get(position), bucket.getKey().toString());
      // Keys of all enclosing group-by levels.
      row.putAll(paramMap);

      // count(1)
      if (StringUtils.isNotBlank(docCount)) {
        row.put(docCount, new BigDecimal(bucket.getDocCount()).toPlainString());
      }

      // count(distinct ...)
      if (StringUtils.isNotBlank(cardinality)) {
        Cardinality cardinalityResponse = bucket.getAggregations().get(cardinality);
        row.put(cardinality, new BigDecimal(cardinalityResponse.getValue()).toPlainString());
      }

      for (String sumField : sumFieldList) {
        Sum sum = bucket.getAggregations().get(sumField);
        // BigDecimal.valueOf uses the double's canonical decimal string; new BigDecimal(double)
        // would expose the binary expansion (0.1 -> 0.1000000000000000055...). Stripping
        // trailing zeros keeps whole numbers rendered as "100" rather than "100.0".
        row.put(sumField, BigDecimal.valueOf(sum.getValue()).stripTrailingZeros().toPlainString());
      }
      returnList.add(row);
    }
    return;
  }

  // Intermediate level: remember this level's key and descend into the next level.
  String currentField = groupByFieldList.get(position);
  int nextPosition = position + 1;
  for (Terms.Bucket bucket : buckets) {
    paramMap.put(currentField, bucket.getKey().toString());
    Terms nestedTerms = bucket.getAggregations().get(groupByFieldList.get(nextPosition));
    recursionCollectResponse(nestedTerms.getBuckets(), groupByFieldList, nextPosition, sumFieldList, returnList, paramMap, docCount, cardinality);
  }
}
/**
 * Builds the single result row for a query without group by.
 *
 * <p>Sum values are rendered in plain (non-scientific) notation from the shortest decimal
 * representation of the double, without trailing zeros.
 *
 * @param searchResponse response returned by Elasticsearch
 * @param sumFieldList   summed fields; each is read from the top-level aggregation of the same name
 * @param docCount       key under which the total hit count is stored; skipped when blank
 * @param cardinality    field used for count(distinct ...); skipped when blank
 * @return a list containing exactly one row
 * @author fageiguanbing
 * @date 2019/7/18
 **/
private List<Map<String, String>> noGroupByCollectResponse(SearchResponse searchResponse, List<String> sumFieldList, String docCount, String cardinality) {
  Map<String, String> row = new HashMap<>();
  for (String sumField : sumFieldList) {
    Sum sum = searchResponse.getAggregations().get(sumField);
    // BigDecimal.valueOf uses the double's canonical decimal string; new BigDecimal(double)
    // would expose the binary expansion (0.1 -> 0.1000000000000000055...). Stripping
    // trailing zeros keeps whole numbers rendered as "100" rather than "100.0".
    row.put(sumField, BigDecimal.valueOf(sum.getValue()).stripTrailingZeros().toPlainString());
  }

  // Without group by, count(1) is the total number of hits of the filtered query.
  if (StringUtils.isNotBlank(docCount)) {
    long totalHits = searchResponse.getHits().getTotalHits();
    row.put(docCount, String.valueOf(totalHits));
  }

  // count(distinct ...)
  if (StringUtils.isNotBlank(cardinality)) {
    Cardinality cardinalityResponse = searchResponse.getAggregations().get(cardinality);
    row.put(cardinality, new BigDecimal(cardinalityResponse.getValue()).toPlainString());
  }

  List<Map<String, String>> rows = new ArrayList<>();
  rows.add(row);
  return rows;
}

本文由 发给官兵 创作,采用 CC BY 3.0 CN 协议进行许可。可自由转载、引用,但需署名作者且注明文章出处。如转载至微信公众号,请在文末添加作者公众号二维码。