1 #ifndef COORDINATE_DESCENT_H 2 #define COORDINATE_DESCENT_H 11 #include <type_traits> 14 #include <eigen3/Eigen/Dense> 20 #include "../../Generic/debug.hpp" 21 #include "../../Generic/generics.hpp" 22 #include "../solver.hpp" 23 #include "../screeningsolver.hpp" 27 template <
typename T,
typename Base =
internal::Solver<T> >
32 const Eigen::Matrix< T, Eigen::Dynamic, 1 >& Y,
33 const Eigen::Matrix< T, Eigen::Dynamic, 1 >& Beta_0 );
37 Eigen::Matrix< T, Eigen::Dynamic, 1 > update_rule(
38 const Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic >& X,
39 const Eigen::Matrix< T, Eigen::Dynamic, 1 >& Y,
40 const Eigen::Matrix< T, Eigen::Dynamic, 1 >& Beta_0,
44 std::vector<T> inverse_norms;
48 template <
typename T,
typename Base >
50 const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> &X,
51 const Eigen::Matrix<T, Eigen::Dynamic, 1> &Y,
52 const Eigen::Matrix<T, Eigen::Dynamic, 1> &Beta_0 ) {
56 inverse_norms.reserve( Beta_0.size() );
58 for(
int i = 0; i < Beta_0.size() ; i++ ) {
60 T X_i_norm = X.col( i ).squaredNorm();
62 T inverse_norm = ( X_i_norm == 0 )?( 0.0 ):(
static_cast<T
>(1)/X_i_norm );
63 inverse_norms.push_back( inverse_norm );
68 template <
typename T,
typename Base >
71 template <
typename T,
typename Base >
73 const Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic >& X,
74 const Eigen::Matrix< T, Eigen::Dynamic, 1 >& Y,
75 const Eigen::Matrix< T, Eigen::Dynamic, 1 >& Beta_0,
78 Eigen::Matrix< T, Eigen::Dynamic, 1 > Beta = Beta_0;
79 Eigen::Matrix< T, Eigen::Dynamic, 1 > Residual = Y - X*Beta_0;
82 for(
int i = 0; i < Beta.size() ; i++ ) {
84 Eigen::Matrix< T, Eigen::Dynamic, 1 > X_i = X.col( i );
86 T inverse_norm = inverse_norms[i];
88 if( Beta( i ) != static_cast<T>(0) ) {
89 Residual = Residual + X_i*Beta( i );
93 T threshold = lambda / (
static_cast<T
>(2) ) * inverse_norm;
94 T elem = inverse_norm*X_i.transpose()*Residual;
96 Beta( i ) = soft_threshold<T>( elem, threshold );
98 if( Beta( i ) !=
static_cast<T
>(0) ) {
99 Residual = Residual - X_i*Beta( i );
104 DEBUG_PRINT(
"Norm Squared of updated Beta: " << Beta.squaredNorm() );
111 #endif // COORDINATE_DESCENT_H
#define DEBUG_PRINT(x,...)