`
zangwenyang
  • 浏览: 124560 次
  • 性别: Icon_minigender_1
  • 来自: 北京
社区版块
存档分类
最新评论

NMF(非负矩阵分解)的SGD(随机梯度下降)实现

 
阅读更多

NMF把一个矩阵分解为两个矩阵的乘积,可以用来解决很多问题,例如:用户聚类、item聚类、预测(补全)用户对item的评分、个性化推荐等问题。NMF的过程可以转化为最小化损失函数(即误差函数)的过程,其实整个问题也就是一个最优化的问题。详细实现过程如下:(其中,输入矩阵很多时候会比较稀疏,即很多元素都是缺失项,故数据存储采用的是libsvm的格式,这个类在此忽略)

 

 

[java] view plaincopy
 
  1. package NMF_danji;  
  2.   
  3. import java.io.File;  
  4. import java.util.ArrayList;  
  5.   
  6. /** 
  7.  * @author 玉心sober: http://weibo.com/karensober 
  8.  * @date 2013-05-19 
  9.  *  
  10.  * */  
  11. public class NMF {  
  12.     private Dataset dataset = null;  
  13.     private int M = -1// 行数  
  14.     private int V = -1// 列数  
  15.     private int K = -1// 隐含主题数  
  16.     double[][] P;  
  17.     double[][] Q;  
  18.   
  19.     public NMF(String datafileName, int topics) {  
  20.         File datafile = new File(datafileName);  
  21.         if (datafile.exists()) {  
  22.             if ((this.dataset = new Dataset(datafile)) == null) {  
  23.                 System.out.println(datafileName + " is null");  
  24.             }  
  25.             this.M = this.dataset.size();  
  26.             this.V = this.dataset.getFeatureNum();  
  27.             this.K = topics;  
  28.         } else {  
  29.             System.out.println(datafileName + " doesn't exist");  
  30.         }  
  31.     }  
  32.   
  33.     public void initPQ() {  
  34.         P = new double[this.M][this.K];  
  35.         Q = new double[this.K][this.V];  
  36.   
  37.         for (int k = 0; k < K; k++) {  
  38.             for (int i = 0; i < M; i++) {  
  39.                 P[i][k] = Math.random();  
  40.             }  
  41.             for (int j = 0; j < V; j++) {  
  42.                 Q[k][j] = Math.random();  
  43.             }  
  44.         }  
  45.     }  
  46.   
  47.     // 随机梯度下降,更新参数  
  48.     public void updatePQ(double alpha, double beta) {  
  49.         for (int i = 0; i < M; i++) {  
  50.             ArrayList<Feature> Ri = this.dataset.getDataAt(i).getAllFeature();  
  51.             for (Feature Rij : Ri) {  
  52.                 // eij=Rij.weight-PQ for updating P and Q  
  53.                 double PQ = 0;  
  54.                 for (int k = 0; k < K; k++) {  
  55.                     PQ += P[i][k] * Q[k][Rij.dim];  
  56.                 }  
  57.                 double eij = Rij.weight - PQ;  
  58.   
  59.                 // update Pik and Qkj  
  60.                 for (int k = 0; k < K; k++) {  
  61.                     double oldPik = P[i][k];  
  62.                     P[i][k] += alpha  
  63.                             * (2 * eij * Q[k][Rij.dim] - beta * P[i][k]);  
  64.                     Q[k][Rij.dim] += alpha  
  65.                             * (2 * eij * oldPik - beta * Q[k][Rij.dim]);  
  66.                 }  
  67.             }  
  68.         }  
  69.     }  
  70.   
  71.     // 每步迭代后计算SSE  
  72.     public double getSSE(double beta) {  
  73.         double sse = 0;  
  74.         for (int i = 0; i < M; i++) {  
  75.             ArrayList<Feature> Ri = this.dataset.getDataAt(i).getAllFeature();  
  76.             for (Feature Rij : Ri) {  
  77.                 double PQ = 0;  
  78.                 for (int k = 0; k < K; k++) {  
  79.                     PQ += P[i][k] * Q[k][Rij.dim];  
  80.                 }  
  81.                 sse += Math.pow((Rij.weight - PQ), 2);  
  82.             }  
  83.         }  
  84.   
  85.         for (int i = 0; i < M; i++) {  
  86.             for (int k = 0; k < K; k++) {  
  87.                 sse += ((beta / 2) * (Math.pow(P[i][k], 2)));  
  88.             }  
  89.         }  
  90.   
  91.         for (int i = 0; i < V; i++) {  
  92.             for (int k = 0; k < K; k++) {  
  93.                 sse += ((beta / 2) * (Math.pow(Q[k][i], 2)));  
  94.             }  
  95.         }  
  96.   
  97.         return sse;  
  98.     }  
  99.   
  100.     // 采用随机梯度下降方法迭代求解参数,即求解最终分解后的矩阵  
  101.     public boolean doNMF(int iters, double alpha, double beta) {  
  102.         for (int step = 0; step < iters; step++) {  
  103.             updatePQ(alpha, beta);  
  104.             double sse = getSSE(beta);  
  105.             if (step % 100 == 0)  
  106.                 System.out.println("step " + step + " SSE = " + sse);  
  107.         }  
  108.         return true;  
  109.     }  
  110.   
  111.     public void printMatrix() {  
  112.         System.out.println("===========原始矩阵==============");  
  113.         for (int i = 0; i < this.dataset.size(); i++) {  
  114.             for (Feature feature : this.dataset.getDataAt(i).getAllFeature()) {  
  115.                 System.out.print(feature.dim + ":" + feature.weight + " ");  
  116.             }  
  117.             System.out.println();  
  118.         }  
  119.     }  
  120.   
  121.     public void printFacMatrxi() {  
  122.         System.out.println("===========分解矩阵==============");  
  123.         for (int i = 0; i < P.length; i++) {  
  124.             for (int j = 0; j < Q[0].length; j++) {  
  125.                 double cell = 0;  
  126.                 for (int k = 0; k < K; k++) {  
  127.                     cell += P[i][k] * Q[k][j];  
  128.                 }  
  129.                 System.out.print(baoliu(cell, 3) + " ");  
  130.             }  
  131.             System.out.println();  
  132.         }  
  133.     }  
  134.   
  135.     // 为double类型变量保留有效数字  
  136.     public static double baoliu(double d, int n) {  
  137.         double p = Math.pow(10, n);  
  138.         return Math.round(d * p) / p;  
  139.     }  
  140.   
  141.     public static void main(String[] args) {  
  142.         double alpha = 0.002;  
  143.         double beta = 0.02;  
  144.   
  145.         NMF nmf = new NMF("D:\\myEclipse\\graphModel\\data\\nmfinput.txt"10);  
  146.         nmf.initPQ();  
  147.         nmf.doNMF(3000, alpha, beta);  
  148.   
  149.         // 输出原始矩阵  
  150.         nmf.printMatrix();  
  151.   
  152.         // 输出分解后矩阵  
  153.         nmf.printFacMatrxi();  
  154.     }  
  155. }  

结果:
...

 

step 2900 SSE = 0.5878774074369989
===========原始矩阵==============
0:9.0 1:2.0 2:1.0 3:1.0 4:1.0 
0:8.0 1:3.0 2:2.0 3:1.0 
0:3.0 3:1.0 4:2.0 5:8.0 
1:1.0 3:2.0 4:4.0 5:7.0 
0:2.0 1:1.0 2:1.0 4:1.0 5:3.0 
===========分解矩阵==============
8.959 2.007 1.007 0.996 1.007 6.293 
7.981 2.972 1.989 1.005 2.046 7.076 
3.01 1.601 1.773 1.003 2.005 7.968 
4.821 1.009 2.209 1.984 3.968 6.988 
2.0 0.991 0.984 0.51 1.0 2.994 

分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics