HDIM  1.0.0
Packages for High Dimensional Linear Regression
ista.hpp
1 #ifndef ISTA_H
2 #define ISTA_H
3 
4 // C System-Headers
5 //
6 // C++ System headers
7 #include <functional>
8 // Eigen Headers
9 //
10 // Boost Headers
11 //
12 // SPAMS Headers
13 //
14 // OpenMP Headers
15 //
16 // Project Specific Headers
17 #include "../../../Generic/generics.hpp"
18 #include "../../../Generic/debug.hpp"
19 #include "../subgradient_descent.hpp"
20 
21 namespace hdim {
22 
23 //template< typename T >
24 //using MatrixT = Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic >;
25 
26 //template< typename T >
27 //using VectorT = Eigen::Matrix< T, Eigen::Dynamic, 1 >;
28 
29 template < typename T, typename Base = internal::Solver< T > >
33 class ISTA : public internal::SubGradientSolver<T,Base> {
34 
35  public:
36  ISTA( T L_0 = 0.1 );
37 
38  protected:
39  Eigen::Matrix< T, Eigen::Dynamic, 1 > update_rule(
40  const Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic >& X,
41  const Eigen::Matrix< T, Eigen::Dynamic, 1 >& Y,
42  const Eigen::Matrix< T, Eigen::Dynamic, 1 >& Beta_0,
43  T lambda );
44 
45  private:
46  const T eta = 1.5;
47  T L = static_cast<T>( 0 );
48 
49 };
50 
51 template < typename T, typename Base >
53 
54 #ifdef DEBUG
55 template < typename T, typename Base >
56 Eigen::Matrix< T, Eigen::Dynamic, 1 > ISTA<T,Base>::update_rule(
57  const Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic >& X,
58  const Eigen::Matrix< T, Eigen::Dynamic, 1 >& Y,
59  const Eigen::Matrix< T, Eigen::Dynamic, 1 >& Beta_0,
60  T lambda ) {
61 
62  Eigen::Matrix< T, Eigen::Dynamic, 1 > Beta = Beta_0;
63 
64  unsigned int counter = 0;
66 
67  Eigen::Matrix< T, Eigen::Dynamic, 1 > Beta_temp = internal::SubGradientSolver<T,Base>::update_beta_ista( X, Y, Beta, L, lambda );
68 
69  counter++;
70  DEBUG_PRINT( "Backtrace iteration: " << counter );
71 
72  while( ( internal::SubGradientSolver<T,Base>::f_beta( X, Y, Beta_temp ) > internal::SubGradientSolver<T,Base>::f_beta_tilda( X, Y, Beta_temp, Beta, L ) ) ) {
73 
74  counter++;
75  DEBUG_PRINT( "Backtrace iteration: " << counter );
76 
77  L*= eta;
78  Beta_temp = internal::SubGradientSolver<T,Base>::update_beta_ista( X, Y, Beta, L, lambda );
79 
80  }
81 
82  return internal::SubGradientSolver<T,Base>::update_beta_ista( X, Y, Beta, L, lambda );
83 }
84 #else
85 template < typename T, typename Base >
86 Eigen::Matrix< T, Eigen::Dynamic, 1 > ISTA<T,Base>::update_rule(
87  const Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic >& X,
88  const Eigen::Matrix< T, Eigen::Dynamic, 1 >& Y,
89  const Eigen::Matrix< T, Eigen::Dynamic, 1 >& Beta,
90  T lambda ) {
91 
93 
94  Eigen::Matrix< T, Eigen::Dynamic, 1 > f_grad = 2.0*( X.transpose()*( X*Beta - Y ) );
95  Eigen::Matrix< T, Eigen::Dynamic, 1 > Beta_temp = ( Beta - (1.0/L)*f_grad ).unaryExpr( SoftThres<T>( lambda/L ) );
96 
97  T f_beta = ( X*Beta_temp - Y ).squaredNorm();
98 
99  Eigen::Matrix< T, Eigen::Dynamic, 1 > f_part = X*Beta - Y;
100  T taylor_term_0 = f_part.squaredNorm();
101 
102  Eigen::Matrix< T, Eigen::Dynamic, 1 > beta_diff = ( Beta_temp - Beta );
103 
104  T taylor_term_1 = f_grad.transpose()*beta_diff;
105 
106  T taylor_term_2 = L/2.0*beta_diff.squaredNorm();
107 
108  T f_beta_tilde = taylor_term_0 + taylor_term_1 + taylor_term_2;
109 
110  while( f_beta > f_beta_tilde ) {
111 
112  L*= eta;
113 
114  Beta_temp = ( Beta - (1.0/L)*f_grad ).unaryExpr( SoftThres<T>( lambda/L ) );
115 
116  f_beta = ( X*Beta_temp - Y ).squaredNorm();;
117 
118  beta_diff = ( Beta_temp - Beta );
119  taylor_term_1 = f_grad.transpose()*beta_diff;
120  taylor_term_2 = L/2.0*beta_diff.squaredNorm();
121 
122  f_beta_tilde = taylor_term_0 + taylor_term_1 + taylor_term_2;
123 
124  }
125 
126  return ( Beta - 1.0/L*f_grad ).unaryExpr( SoftThres<T>( lambda/L ) );
127 }
128 #endif
129 
130 }
131 
132 
133 #endif // ISTA_H
Definition: fos.hpp:18
Abstract base class for Sub-Gradient Descent algorithms ,such as ISTA and FISTA, with backtracking li...
Soft Threshold functor used to apply hdim::soft_threshold to each element in a matrix or vector...
Definition: generics.hpp:288
#define DEBUG_PRINT(x,...)
Definition: debug.hpp:73
Run the Iterative Shrinking and Thresholding Algorthim.
Definition: ista.hpp:33