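The basic idea of the K-SVD algorithm:
Y is the matrix of training samples, D is the dictionary, and X is the sparse coefficient matrix, so that Y ≈ DX.
Each iteration generally consists of two steps, Sparse Coding and Dictionary Update:
1: Sparse Coding: with the dictionary D fixed, use a pursuit algorithm (OMP in the code below) to find the best sparse coefficient matrix X for the samples under the sparse-approximation objective.
2: Dictionary Update: update the dictionary column by column, following the criterion that each update should reduce the MSE. SVD (singular value decomposition) is used to update, in sequence, each column and the sparse coefficients that correspond to that column.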
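The objective function for the Sparse Coding step is not reproduced in this post. As a sketch, the formulation usually given for K-SVD in the paper cited in the KSVD code header is the following, where x_i denotes the i-th column of X and the sparsity bound T_0 is a symbol assumed here for illustration:

```latex
\min_{D,X} \; \lVert Y - DX \rVert_F^2
\quad \text{subject to} \quad
\lVert x_i \rVert_0 \le T_0 \;\; \text{for all } i
```

In the code, param.L plays the role of T_0 when param.errorFlag = 0.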
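E_k is the residual associated with the k-th column of the dictionary. Its physical meaning is the representation error when d_k is left out, which shows how much the k-th column actually contributes to representing Y.
From this interpretation of E_k, our goal is to find a d_k (together with its coefficients) that accounts for as much of E_k as possible.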
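Written out, and assuming x_T^j denotes the j-th row of X (a symbol introduced here for illustration), the total error splits into a part that does not involve d_k and the rank-one term contributed by d_k:

```latex
\lVert Y - DX \rVert_F^2
= \Bigl\lVert Y - \sum_{j} d_j x_T^j \Bigr\rVert_F^2
= \Bigl\lVert \Bigl( Y - \sum_{j \ne k} d_j x_T^j \Bigr) - d_k x_T^k \Bigr\rVert_F^2
= \lVert E_k - d_k x_T^k \rVert_F^2
```

Under this notation, E_k = Y - \sum_{j \ne k} d_j x_T^j, and updating the k-th column amounts to finding the best rank-one approximation of E_k.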
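To obtain d_k, apply SVD (singular value decomposition) to E_k:
E_k = U Δ V^T
Take the first column of U as the updated k-th dictionary column d_k, and take Δ(1,1) times the first column of V as the updated sparse coefficients for that column. In the KSVD code below, E_k is first restricted to the samples that actually use d_k (relevantDataIndices), so the update does not introduce new nonzero coefficients.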
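The update rule can be checked numerically. The following is a minimal sketch on random stand-in data, not the author's implementation; all sizes and variable names (omega, Ek, errBefore, errAfter) are hypothetical. It restricts E_k to the samples that use atom k, as the KSVD code does, and compares the squared error before and after the update.

```matlab
% Rank-one atom update on random stand-in data (all sizes are hypothetical).
n = 8; K = 5; N = 30; k = 2;
D = randn(n,K);
D = D ./ repmat(sqrt(sum(D.^2,1)),[n,1]);      % unit-norm atoms
X = randn(K,N) .* (rand(K,N) < 0.4);           % sparse coefficient matrix
X(k,1) = 1;                                    % make sure atom k is used at least once
Y = randn(n,N);                                % stand-in training samples

errBefore = norm(Y - D*X,'fro')^2;

omega = find(X(k,:));                          % samples that use atom k
Xk = X(:,omega);
Xk(k,:) = 0;                                   % drop atom k's contribution
Ek = Y(:,omega) - D*Xk;                        % residual without d_k, restricted to omega

[U,S,V] = svd(Ek,'econ');
D(:,k) = U(:,1);                               % updated atom d_k
X(k,omega) = S(1,1)*V(:,1)';                   % updated coefficients of atom k

errAfter = norm(Y - D*X,'fro')^2;
fprintf('squared error before: %.4f, after: %.4f\n', errBefore, errAfter);
```

Because the leading singular triplet gives the best rank-one approximation of Ek, the error after the update should not exceed the error before it, and the sparsity pattern of X is unchanged.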
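Below is a simple demo that uses the KSVD and OMP algorithms.
Code flow:
Step 1: read in a lena image, img
Step 2: randomly generate a measurement matrix, Phi
Step 3: y = Phi*img gives the measurements
Step 4: [Dictionary, output] = KSVD(img, param) gives the dictionary
Step 5: A = OMP(Phi*Dictionary, y, L) gives the sparse coefficient matrix
Step 6: imgr = Dictionary*A gives the reconstructed image.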
Demo_Code_1.m
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% the K-SVD basis is selected as the sparse representation dictionary
% the OMP algorithm is used to recover the image
% Author: zhang ben, ncuzhangben@qq.com
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%***************************** read in the image **************************
img=imread('lena.bmp'); % read in the image "lena.bmp"
img=double(img);
[N,n]=size(img); % Phi*img below assumes a square image (N == n)
img0 = img; % keep an original copy of the input signal
%**************** form the measurement matrix and Dictionary ***************
%form the measurement matrix Phi
Phi=randn(N,n);
Phi = Phi./repmat(sqrt(sum(Phi.^2,1)),[N,1]); % normalize each column
%fix the parameters
param.L =20; % number of elements in each linear combination.
param.K =150; % number of dictionary elements
param.numIteration = 50; % number of iterations to execute the K-SVD algorithm.
param.errorFlag = 0; % 0: use a fixed number of coefficients (param.L) per signal;
                     % 1 would instead decompose until param.errorGoal is reached.
%param.errorGoal = sigma;
param.preserveDCAtom = 0;
param.InitializationMethod ='DataElements'; % initialization by the signals themselves
param.displayProgress = 1; % progress information is displayed.
[Dictionary,output]= KSVD(img,param); % Dictionary is N*param.K
%************************ projection **************************************
y=Phi*img; % treat each column as an independent signal
y0=y; % keep an original copy of the measurements
%********************* recover using OMP *********************************
D=Phi*Dictionary; % note: OMP expects unit-norm columns; D is not renormalized here
A=OMP(D,y,param.L);
imgr=Dictionary*A;
%*********************** show the results ********************************
figure(1)
subplot(2,2,1),imagesc(img0),title('original image')
subplot(2,2,2),imagesc(y0),title('measurement image')
subplot(2,2,3),imagesc(Dictionary),title('Dictionary')
psnrVal=20*log10(255/sqrt(mean((img(:)-imgr(:)).^2))); % named psnrVal to avoid shadowing the toolbox function psnr
subplot(2,2,4),imagesc(imgr),title(strcat('recover image (',num2str(psnrVal),'dB)'))
disp('over')
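The OMP header states that the dictionary columns must be normalized, while the demo passes Phi*Dictionary as is. One possible way to meet that requirement is sketched below: normalize the columns before calling OMP, then rescale the coefficients afterwards. This is an assumption about how one might adapt the demo, not part of the original code; the random stand-in data and the names colNorm, Dn, An, and recon are hypothetical, and OMP.m from this post is assumed to be on the MATLAB path.

```matlab
% Stand-in data (hypothetical sizes); in the demo these come from Phi, Dictionary and y.
n = 16; K = 24; P = 5; L = 4;
Phi = randn(n,n);
Dictionary = randn(n,K);
Dictionary = Dictionary ./ repmat(sqrt(sum(Dictionary.^2,1)),[n,1]);
y = Phi*randn(n,P);

D = Phi*Dictionary;                  % effective dictionary seen by OMP
colNorm = sqrt(sum(D.^2,1));         % 1 x K vector of column norms
Dn = D ./ repmat(colNorm,[n,1]);     % unit-norm columns, as OMP expects
An = OMP(Dn,y,L);                    % coefficients with respect to Dn
A = diag(1./colNorm)*An;             % rescale rows so that D*A equals Dn*An
recon = Dictionary*A;                % plays the role of imgr in the demo
```

Since Dn = D*diag(1./colNorm), dividing row k of An by colNorm(k) gives coefficients that reproduce the same approximation with the unnormalized D.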
OMP.m (code written by another user online)
function [A]=OMP(D,X,L)
%=============================================
% Sparse coding of a group of signals based on a given
% dictionary and specified number of atoms to use.
% input arguments:
%   D - the dictionary (its columns MUST be normalized).
%   X - the signals to represent
%   L - the max. number of coefficients for each signal.
% output arguments:
%   A - sparse coefficient matrix.
%=============================================
[n,K]=size(D);
[n,P]=size(X);
A=sparse(K,P); % preallocate the K*P output
for k=1:1:P,
    a=[];
    x=X(:,k); % x is the k-th column of X, an n*1 vector
    residual=x; % n*1
    indx=zeros(L,1); % L*1 vector of zeros
    for j=1:1:L,
        proj=D'*residual; % (K*n)*(n*1) gives K*1
        [maxVal,pos]=max(abs(proj)); % position of the largest projection coefficient
        pos=pos(1);
        indx(j)=pos;
        a=pinv(D(:,indx(1:j)))*x; % least-squares coefficients on the selected atoms
        residual=x-D(:,indx(1:j))*a;
        if sum(residual.^2) < 1e-6
            break;
        end
    end;
    temp=zeros(K,1);
    temp(indx(1:j))=a;
    A(:,k)=sparse(temp); % A is returned as a K*P matrix
end;
return;
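A small usage check can make the behaviour of OMP concrete. The sketch below assumes OMP.m from this post is on the MATLAB path; the tiny dictionary and signal are made up for illustration. The signal is built from two atoms, and OMP is asked for at most two coefficients.

```matlab
% Small OMP check; assumes OMP.m from this post is on the MATLAB path.
D = [eye(4), ones(4,1)/2];        % 4x5 dictionary; every column has unit norm
x = 3*D(:,2) - 2*D(:,4);          % signal built from atoms 2 and 4
A = OMP(D, x, 2);                 % allow at most 2 nonzero coefficients
disp(full(A)');                   % coefficient vector, one entry per atom
```

In the first pass the projections are [0 3 0 -2 0.5], so atom 2 is selected; the residual then points along atom 4, which is selected next, and the inner loop stops once the residual energy drops below 1e-6.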
KSVD algorithm implementation code:
function [Dictionary,output] = KSVD(...
    Data,... % an nXN matrix that contains N signals (Y), each of dimension n.
    param)
% =========================================================================
%                          K-SVD algorithm
% =========================================================================
% The K-SVD algorithm finds a dictionary for linear representation of
% signals. Given a set of signals, it searches for the best dictionary that
% can sparsely represent each signal. Detailed discussion on the algorithm
% and possible applications can be found in "The K-SVD: An Algorithm for
% Designing of Overcomplete Dictionaries for Sparse Representation", written
% by M. Aharon, M. Elad, and A.M. Bruckstein and appeared in the IEEE Trans.
% On Signal Processing, Vol. 54, no. 11, pp. 4311-4322, November 2006.
% =========================================================================
% INPUT ARGUMENTS:
% Data      an nXN matrix that contains N signals (Y), each of dimension n.
% param     structure that includes all required
%           parameters for the K-SVD execution.
%           Required fields are:
%    K, ...                 the number of dictionary elements to train
%    numIteration,...       number of iterations to perform.
%    errorFlag...           if =0, a fixed number of coefficients is
%                           used for representation of each signal. If so, param.L must be
%                           specified as the number of representing atoms. if =1, arbitrary number
%                           of atoms represent each signal, until a specific representation error
%                           is reached. If so, param.errorGoal must be specified as the allowed
%                           error.
%    preserveDCAtom...      if =1 then the first atom in the dictionary
%                           is set to be constant, and does not ever change. This
%                           might be useful for working with natural
%                           images (in this case, only param.K-1
%                           atoms are trained).
%    (optional, see errorFlag) L,...                 % maximum coefficients to use in OMP coefficient calculations.
%    (optional, see errorFlag) errorGoal, ...        % allowed representation error in representing each signal.
%    InitializationMethod,...  method to initialize the dictionary, can
%                           be one of the following arguments:
%                           * 'DataElements' (initialization by the signals themselves), or:
%                           * 'GivenMatrix' (initialization by a given matrix param.initialDictionary).
%    (optional, see InitializationMethod) initialDictionary,...  % if the initialization method
%                           is 'GivenMatrix', this is the matrix that will be used.
%    (optional) TrueDictionary, ...   % if specified, in each
%                           iteration the difference between this dictionary and the trained one
%                           is measured and displayed.
%    displayProgress, ...   if =1 progress information is displayed. If param.errorFlag==0,
%                           the average representation error (RMSE) is displayed, while if
%                           param.errorFlag==1, the average number of required coefficients for
%                           representation of each signal is displayed.
% =========================================================================
% OUTPUT ARGUMENTS:
% Dictionary    The extracted dictionary of size nX(param.K).
% output        Struct that contains information about the current run. It may include the following fields:
%    CoefMatrix     The final coefficients matrix (it should hold that Data equals approximately Dictionary*output.CoefMatrix).
%    ratio          If the true dictionary was defined (in
%                   synthetic experiments), this parameter holds a vector of length
%                   param.numIteration that includes the detection ratios in each
%                   iteration.
%    totalerr       The total representation error after each
%                   iteration (defined only if
%                   param.displayProgress=1 and
%                   param.errorFlag = 0)
%    numCoef        A vector of length param.numIteration that
%                   includes the average number of coefficients required for representation
%                   of each signal (in each iteration) (defined only if
%                   param.displayProgress=1 and
%                   param.errorFlag = 1)
% =========================================================================

if (~isfield(param,'displayProgress'))
    param.displayProgress = 0;
end
totalerr(1) = 99999;
if (isfield(param,'errorFlag')==0)
    param.errorFlag = 0;
end

if (isfield(param,'TrueDictionary'))
    displayErrorWithTrueDictionary = 1;
    ErrorBetweenDictionaries = zeros(param.numIteration+1,1); % preallocate with zeros
    ratio = zeros(param.numIteration+1,1);
else
    displayErrorWithTrueDictionary = 0;
    ratio = 0;
end
if (param.preserveDCAtom>0)
    FixedDictionaryElement(1:size(Data,1),1) = 1/sqrt(size(Data,1));
else
    FixedDictionaryElement = [];
end
% coefficient calculation method is OMP with fixed number of coefficients

if (size(Data,2) < param.K)
    disp('Size of data is smaller than the dictionary size. Trivial solution...');
    Dictionary = Data(:,1:size(Data,2));
    output = struct(); % assign the second output so this early return does not error
    return;
elseif (strcmp(param.InitializationMethod,'DataElements'))
    Dictionary(:,1:param.K-param.preserveDCAtom) = Data(:,1:param.K-param.preserveDCAtom);
elseif (strcmp(param.InitializationMethod,'GivenMatrix'))
    Dictionary(:,1:param.K-param.preserveDCAtom) = param.initialDictionary(:,1:param.K-param.preserveDCAtom);
end
% reduce the components in Dictionary that are spanned by the fixed
% elements
if (param.preserveDCAtom)
    tmpMat = FixedDictionaryElement \ Dictionary;
    Dictionary = Dictionary - FixedDictionaryElement*tmpMat;
end
%normalize the dictionary.
Dictionary = Dictionary*diag(1./sqrt(sum(Dictionary.*Dictionary)));
Dictionary = Dictionary.*repmat(sign(Dictionary(1,:)),size(Dictionary,1),1); % multiply in the sign of the first element.
totalErr = zeros(1,param.numIteration);

% the K-SVD algorithm starts here.

for iterNum = 1:param.numIteration
    % find the coefficients
    if (param.errorFlag==0)
        %CoefMatrix = mexOMPIterative2(Data, [FixedDictionaryElement,Dictionary],param.L);
        CoefMatrix = OMP([FixedDictionaryElement,Dictionary],Data, param.L);
    else
        %CoefMatrix = mexOMPerrIterative(Data, [FixedDictionaryElement,Dictionary],param.errorGoal);
        % note: OMPerr is not listed in this post
        CoefMatrix = OMPerr([FixedDictionaryElement,Dictionary],Data, param.errorGoal);
        param.L = 1;
    end

    replacedVectorCounter = 0;
    rPerm = randperm(size(Dictionary,2)); % update the atoms in random order
    for j = rPerm
        [betterDictionaryElement,CoefMatrix,addedNewVector] = I_findBetterDictionaryElement(Data,...
            [FixedDictionaryElement,Dictionary],j+size(FixedDictionaryElement,2),...
            CoefMatrix ,param.L);
        Dictionary(:,j) = betterDictionaryElement;
        if (param.preserveDCAtom)
            tmpCoef = FixedDictionaryElement\betterDictionaryElement;
            Dictionary(:,j) = betterDictionaryElement - FixedDictionaryElement*tmpCoef;
            Dictionary(:,j) = Dictionary(:,j)./sqrt(Dictionary(:,j)'*Dictionary(:,j));
        end
        replacedVectorCounter = replacedVectorCounter+addedNewVector;
    end

    if (iterNum>1 && param.displayProgress)
        if (param.errorFlag==0)
            output.totalerr(iterNum-1) = sqrt(sum(sum((Data-[FixedDictionaryElement,Dictionary]*CoefMatrix).^2))/numel(Data));
            disp(['Iteration ',num2str(iterNum),' Total error is: ',num2str(output.totalerr(iterNum-1))]);
        else
            output.numCoef(iterNum-1) = length(find(CoefMatrix))/size(Data,2);
            disp(['Iteration ',num2str(iterNum),' Average number of coefficients: ',num2str(output.numCoef(iterNum-1))]);
        end
    end
    if (displayErrorWithTrueDictionary )
        [ratio(iterNum+1),ErrorBetweenDictionaries(iterNum+1)] = I_findDistanseBetweenDictionaries(param.TrueDictionary,Dictionary);
        disp(strcat(['Iteration ', num2str(iterNum),' ratio of restored elements: ',num2str(ratio(iterNum+1))]));
        output.ratio = ratio;
    end
    Dictionary = I_clearDictionary(Dictionary,CoefMatrix(size(FixedDictionaryElement,2)+1:end,:),Data);

    if (isfield(param,'waitBarHandle'))
        waitbar(iterNum/param.counterForWaitBar);
    end
end

output.CoefMatrix = CoefMatrix;
Dictionary = [FixedDictionaryElement,Dictionary];

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% findBetterDictionaryElement
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [betterDictionaryElement,CoefMatrix,NewVectorAdded] = I_findBetterDictionaryElement(Data,Dictionary,j,CoefMatrix,numCoefUsed)
if (nargin < 5) % numCoefUsed was not passed in
    numCoefUsed = 1;
end
relevantDataIndices = find(CoefMatrix(j,:)); % the data indices that use the j'th dictionary element.
if (length(relevantDataIndices)<1) %(length(relevantDataIndices)==0)
    % the atom is unused: replace it with the worst-represented signal
    ErrorMat = Data-Dictionary*CoefMatrix;
    ErrorNormVec = sum(ErrorMat.^2);
    [d,i] = max(ErrorNormVec);
    betterDictionaryElement = Data(:,i);%ErrorMat(:,i); %
    betterDictionaryElement = betterDictionaryElement./sqrt(betterDictionaryElement'*betterDictionaryElement);
    betterDictionaryElement = betterDictionaryElement.*sign(betterDictionaryElement(1));
    CoefMatrix(j,:) = 0;
    NewVectorAdded = 1;
    return;
end

NewVectorAdded = 0;
tmpCoefMatrix = CoefMatrix(:,relevantDataIndices);
tmpCoefMatrix(j,:) = 0;% the coefficients of the element we now improve are not relevant.
errors =(Data(:,relevantDataIndices) - Dictionary*tmpCoefMatrix); % vector of errors that we want to minimize with the new element
% % the better dictionary element and the values of beta are found using svd.
% % This is because we would like to minimize || errors - beta*element ||_F^2.
% % that is, to approximate the matrix 'errors' with a one-rank matrix. This
% % is done using the largest singular value.
[betterDictionaryElement,singularValue,betaVector] = svds(errors,1);
CoefMatrix(j,relevantDataIndices) = singularValue*betaVector';% *signOfFirstElem

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% findDistanseBetweenDictionaries
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [ratio,totalDistances] = I_findDistanseBetweenDictionaries(original,new)
% first, make every column start with a positive value.
catchCounter = 0;
totalDistances = 0;
for i = 1:size(new,2)
    new(:,i) = sign(new(1,i))*new(:,i);
end
for i = 1:size(original,2)
    d = sign(original(1,i))*original(:,i);
    distances =sum ( (new-repmat(d,1,size(new,2))).^2);
    [minValue,index] = min(distances);
    errorOfElement = 1-abs(new(:,index)'*d);
    totalDistances = totalDistances+errorOfElement;
    catchCounter = catchCounter+(errorOfElement<0.01);
end
ratio = 100*catchCounter/size(original,2);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% I_clearDictionary
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function Dictionary = I_clearDictionary(Dictionary,CoefMatrix,Data)
T2 = 0.99; % atoms with correlation above T2 are treated as duplicates
T1 = 3;    % atoms used by at most T1 signals are treated as rarely used
K=size(Dictionary,2);
Er=sum((Data-Dictionary*CoefMatrix).^2,1); % per-signal error; remove identical atoms
G=Dictionary'*Dictionary; G = G-diag(diag(G));
for jj=1:1:K,
    if max(G(jj,:))>T2 || length(find(abs(CoefMatrix(jj,:))>1e-7))<=T1 ,
        [val,pos]=max(Er);
        Er(pos(1))=0;
        Dictionary(:,jj)=Data(:,pos(1))/norm(Data(:,pos(1)));
        G=Dictionary'*Dictionary; G = G-diag(diag(G));
    end;
end;
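One way to exercise KSVD without an image is a synthetic experiment. The sketch below assumes KSVD.m and OMP.m from this post are on the MATLAB path. It generates signals from a known random dictionary and passes that dictionary as param.TrueDictionary, so that output.ratio is filled in as described in the header. All sizes and the names Dtrue, Dlearned, and rmse are hypothetical.

```matlab
% Synthetic K-SVD run; assumes KSVD.m and OMP.m from this post are on the path.
n = 20; K = 30; N = 1500; L = 3;                 % hypothetical problem sizes
Dtrue = randn(n,K);
Dtrue = Dtrue ./ repmat(sqrt(sum(Dtrue.^2,1)),[n,1]);   % unit-norm true atoms
X = zeros(K,N);
for i = 1:N
    idx = randperm(K);
    X(idx(1:L),i) = randn(L,1);                  % L random nonzeros per signal
end
Y = Dtrue*X;                                     % noiseless training signals

param.K = K;
param.L = L;
param.numIteration = 20;
param.errorFlag = 0;                             % fixed number of coefficients
param.preserveDCAtom = 0;
param.InitializationMethod = 'DataElements';
param.displayProgress = 0;
param.TrueDictionary = Dtrue;                    % enables output.ratio

[Dlearned,output] = KSVD(Y,param);
rmse = sqrt(mean(mean((Y - Dlearned*output.CoefMatrix).^2)));
fprintf('RMSE: %g, recovered atoms: %g%%\n', rmse, output.ratio(end));
```

The printed values depend on the random draw. Note that output.CoefMatrix is computed before the final I_clearDictionary call, so any atom replaced in the last iteration is not reflected in the coefficients.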