• Spring Boot Spark Kmeans 알고리즘 사용법

    2022. 6. 7.

    by. 순일

    Kmeans란?

    Kmeans란 K-평균 알고리즘을 주어진 데이터를 K개의 클러스터로 묶는 알고리즘입니다. 각 클러스터와 거리 차이 및 분산을 최소화하여 비슷한 유형끼리 그룹화 함으로써 라벨이 달려 있지 않는 데이터에 라벨을 달아주는 역할을 수행합니다.

     

    유기동물을 매칭 해주는 프로젝트를 진행하면서 내가 입력한 실종 동물의 정보와 유사한 유기동물을 매칭 해주는 기능을 만들기 위해 K means 알고리즘을 사용하기로 하였다. 입력 데이터를 Vector화(String을 double 형으로 변환.) 하는 작업이 추가적으로 필요하다. 

     

    1. 먼저 spakr의 mllib를 dependencies에 추가하여야 한다.
     

    Maven Repository: org.apache.spark » spark-mllib_2.13 » 3.2.0

     

    mvnrepository.com

    maven 사이트에서 추천하는 방식 은 아래와 같다.

    compileOnly group: 'org.apache.spark', name: 'spark-mllib_2.12', version: '3.1.3'

    하지만 slf4j-log4j12 모듈을 이미 사용하고 있기 때문에 제외(exclude group) 시켜줘야 한다.

    compileOnly ('org.apache.spark:spark-mllib_2.12:2.4.3'){
    	exclude group : "org.slf4j", module : "slf4j-log4j12"
    }

    하지만 위와같이 변환해도 에러가 발생하는데 더 이상 해당 버전을 지원하지 않는다는 걸 알게 되었고 아래와 같은 방식으로 변환해서 사용해야 한다.

    // spark kmeans
    // https://mvnrepository.com/artifact/org.apache.spark/spark-mllib
    implementation ('org.apache.spark:spark-mllib_2.12:3.1.3') {
    	exclude group: "org.slf4j", module: "slf4j-log4j12"
    }

    만약 해당 모듈을(slf4 j-log4 j12) 이미 사용하고 있지 않다면 아래와 같은 방식으로 gradle을 추가해주면 된다.

    implementation 'org.apache.spark:spark-mllib_2.12:3.1.3'

     

    2. Kmeans 기본 예제 사용하되 응용하여 코드를 작성한다.
     

    Clustering - RDD-based API - Spark 2.2.0 Documentation

    You are using an outdated browser. Upgrade your browser today or install Google Chrome Frame to better experience this site. Overview Programming Guides API Docs Deploying More v2.2.0 -->

    spark.apache.org

    [기본 예제]

    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.mllib.clustering.GaussianMixture;
    import org.apache.spark.mllib.clustering.GaussianMixtureModel;
    import org.apache.spark.mllib.linalg.Vector;
    import org.apache.spark.mllib.linalg.Vectors;
    
    // Load and parse data
    String path = "data/mllib/gmm_data.txt";
    JavaRDD<String> data = jsc.textFile(path);
    JavaRDD<Vector> parsedData = data.map(s -> {
      String[] sarray = s.trim().split(" ");
      double[] values = new double[sarray.length];
      for (int i = 0; i < sarray.length; i++) {
        values[i] = Double.parseDouble(sarray[i]);
      }
      return Vectors.dense(values);
    });
    parsedData.cache();
    
    // Cluster the data into two classes using GaussianMixture
    GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd());
    
    // Save and load GaussianMixtureModel
    gmm.save(jsc.sc(), "target/org/apache/spark/JavaGaussianMixtureExample/GaussianMixtureModel");
    GaussianMixtureModel sameModel = GaussianMixtureModel.load(jsc.sc(),
      "target/org.apache.spark.JavaGaussianMixtureExample/GaussianMixtureModel");
    
    // Output the parameters of the mixture model
    for (int j = 0; j < gmm.k(); j++) {
      System.out.printf("weight=%f\nmu=%s\nsigma=\n%s\n",
        gmm.weights()[j], gmm.gaussians()[j].mu(), gmm.gaussians()[j].sigma());
    }

    [기본 예제를 응용한 소스 코드]

    	static SparkConf conf = new SparkConf().setAppName("JavaKMeansExample").setMaster("local").set("spark.driver.allowMultipleContexts","true");
    	static JavaSparkContext jsc = new JavaSparkContext(conf);
    
    	@Override
    	public List<DatasetDto> matching(Map<String, String> map) {
        		...
    		// CSV 파일 읽고 데이터 저장
    		String path = "/recommend.csv";
    		JavaRDD<String> data = jsc.textFile(path);
    
    		
    		// 데이터를 RDD로 변환 
    		JavaRDD<Vector> parsedData = data.map(s -> {
    			String[] sarray = s.split("\\|");
    			double[] values = new double[sarray.length]; 
    			for (int i = 0; i < sarray.length; i++) {
    				values[i] = Double.parseDouble(sarray[i]);
    			}
    			return Vectors.dense(values);
    		});
    
    		// 반복 알고리즘에 대해 데이터 캐쉬로 속도 향상 시켜줌
    		parsedData.cache();
    
    		// 클러스터 갯수와 반복 횟수 정의
    		int numClusters = (int) data.count();
    		int numIterations = 20;
    
    		KMeansModel clusters = KMeans.train(parsedData.rdd(), numClusters, numIterations);
    
    		List<DatasetDto> result = new ArrayList<DatasetDto>();
    
    		List<Integer> result_int = clusters.predict(parsedData).collect();
    		for (int j = 1; j < result_int.size(); j++) {
    			if (result_int.get(j) == result_int.get(0))
    				result.add(list.get(j));
    		}
     
    		result.sort((a,b) -> b.getHappenDt().compareTo(a.getHappenDt()));
    
    		return result;
    	}

    프로젝트에서 총 4가지의 입력값으로 비교를 했는데 품종, 지역, 성별, 생상을 입력받으면 먼저 데이터베이스에 저장된 값들 중에 품종이 같은 것만 가져온다. 이때 리스트에 첫 번째 값으로 유저가 입력한 값을 넣어준다. 이후 지역, 성별, 생상은 자체적으로 벡터화하여 recommend.csv파일을 만들어 준다.

    이후 recommend.csv를 읽어서 RDD로 변환한 뒤 Kmeans를 돌리게 되고 그 결괏값을 가지고 첫번째 값과 같은 결괏값을 가지는 RDD를 모두 리스트에 담아 반환하는 방식으로 구현하였다.

     

    참고 사이트
    728x90

    '공부 > BackEnd' 카테고리의 다른 글

    Spring Boot @Scheduled 스케줄링 사용법  (0) 2022.04.17
    Spring Boot Web Socket 실시간 알림  (0) 2022.02.21
    Spring Boot 복잡한 JSON 파싱  (0) 2022.02.13

    댓글