Commit 81b0895c authored by Gregory C. Sharp's avatar Gregory C. Sharp

Add support for specifying multiple regularization options

parent aaef9ad6
......@@ -319,7 +319,7 @@ report_score (
)
{
Bspline_score* ssd = &bst->ssd;
Regularization_parms* reg_parms = parms->reg_parms;
const Regularization_parms* rparms = parms->regularization_parms;
Bspline_landmarks* blm = parms->blm;
int i;
......@@ -352,7 +352,7 @@ report_score (
/* First line, iterations, score, misc stats */
logfile_printf ("[%2d,%3d] ", bst->it, bst->feval);
if (reg_parms->lambda > 0
if (rparms->curvature_penalty > 0
|| blm->num_landmarks > 0
|| bst->similarity_data.size() > 1)
{
......@@ -367,7 +367,7 @@ report_score (
hack_num_vox, ssd_grad_mean, sqrt (ssd_grad_norm), total_time);
/* Second line */
if (reg_parms->lambda > 0
if (rparms->curvature_penalty > 0
|| blm->num_landmarks > 0
|| bst->similarity_data.size() > 1)
{
......@@ -385,16 +385,16 @@ report_score (
++it_mr, ++it_st;
}
if (ssd->metric_record.size() > 1
&& (reg_parms->lambda > 0 || blm->num_landmarks > 0))
&& (rparms->curvature_penalty > 0 || blm->num_landmarks > 0))
{
logfile_printf ("\n");
logfile_printf (" ");
}
if (reg_parms->lambda > 0 || blm->num_landmarks > 0) {
if (rparms->curvature_penalty > 0 || blm->num_landmarks > 0) {
/* Part 2 - regularization metric */
if (reg_parms->lambda > 0) {
if (rparms->curvature_penalty > 0) {
logfile_printf ("RM %9.3f ",
reg_parms->lambda * bst->ssd.rmetric);
rparms->curvature_penalty * bst->ssd.rmetric);
}
/* Part 3 - landmark metric */
if (blm->num_landmarks > 0) {
......@@ -402,7 +402,7 @@ report_score (
blm->landmark_stiffness * bst->ssd.lmetric);
}
/* Part 4 - timing */
if (reg_parms->lambda > 0) {
if (rparms->curvature_penalty > 0) {
logfile_printf ("[ %9.3f | %9.3f ]",
total_smetric_time, ssd->time_rmetric);
}
......@@ -418,7 +418,7 @@ bspline_score (Bspline_optimize *bod)
Bspline_state *bst = bod->get_bspline_state ();
Bspline_xform *bxf = bod->get_bspline_xform ();
Regularization_parms* reg_parms = parms->reg_parms;
const Regularization_parms* rparms = parms->regularization_parms;
Bspline_landmarks* blm = parms->blm;
/* Zero out the score for this iteration */
......@@ -469,18 +469,15 @@ bspline_score (Bspline_optimize *bod)
}
/* Compute regularization */
if (reg_parms->lambda > 0.0f) {
bst->rst.compute_score (&bst->ssd, reg_parms, bxf);
if (rparms->implementation != '\0') {
bst->rst.compute_score (&bst->ssd, rparms, bxf);
bst->ssd.total_score +=
rparms->curvature_penalty * bst->ssd.rmetric;
}
/* Compute landmark score/gradient to image score/gradient */
if (blm->num_landmarks > 0) {
bspline_landmarks_score (parms, bst, bxf);
}
/* Update total score with regularization and landmarks */
bst->ssd.total_score += reg_parms->lambda * bst->ssd.rmetric;
if (blm->num_landmarks > 0) {
bst->ssd.total_score += blm->landmark_stiffness * bst->ssd.lmetric;
}
......
......@@ -36,7 +36,7 @@ Bspline_parms::Bspline_parms ()
this->fixed_stiffness = 0;
this->reg_parms = new Regularization_parms;
this->regularization_parms = 0;
this->blm = new Bspline_landmarks;
this->rbf_radius = 0;
......@@ -47,7 +47,6 @@ Bspline_parms::Bspline_parms ()
Bspline_parms::~Bspline_parms ()
{
delete this->blm;
delete this->reg_parms;
}
void
......
......@@ -77,7 +77,7 @@ public:
float mi_moving_image_maxVal;
/* Regularization */
Regularization_parms* reg_parms;
const Regularization_parms* regularization_parms;
Volume* fixed_stiffness;
/* Landmarks */
......
......@@ -64,7 +64,7 @@ Bspline_regularize::~Bspline_regularize ()
void
Bspline_regularize::initialize (
Regularization_parms *reg_parms,
const Regularization_parms *reg_parms,
Bspline_xform* bxf
)
{
......
......@@ -5,23 +5,12 @@
#define _bspline_regularize_h_
#include "plmregister_config.h"
#include "regularization_parms.h"
#include "volume.h"
class Bspline_score;
class Bspline_xform;
class Regularization_parms
{
public:
char implementation; /* Implementation: a, b, c, etc */
float lambda; /* Smoothness weighting factor */
public:
Regularization_parms () {
this->implementation = '\0';
this->lambda = 0.0f;
}
};
class PLMREGISTER_API Bspline_regularize {
public:
SMART_POINTER_SUPPORT (Bspline_regularize);
......@@ -30,7 +19,7 @@ public:
~Bspline_regularize ();
public:
/* all methods */
Regularization_parms *reg_parms;
const Regularization_parms *reg_parms;
Bspline_xform *bxf;
Volume* fixed_stiffness;
......@@ -55,7 +44,7 @@ public:
double* cond;
public:
void initialize (
Regularization_parms* reg_parms,
const Regularization_parms* reg_parms,
Bspline_xform* bxf
);
void compute_score (
......
......@@ -359,9 +359,9 @@ region_smoothness_omp (
/* ------------------------------------------------ */
/* dS/dp = 2Vp operation */
sets[3*j+0] += 2 * reg_parms->lambda * X[j];
sets[3*j+1] += 2 * reg_parms->lambda * Y[j];
sets[3*j+2] += 2 * reg_parms->lambda * Z[j];
sets[3*j+0] += 2 * reg_parms->curvature_penalty * X[j];
sets[3*j+1] += 2 * reg_parms->curvature_penalty * Y[j];
sets[3*j+2] += 2 * reg_parms->curvature_penalty * Z[j];
}
return S;
......@@ -396,9 +396,9 @@ region_smoothness (
/* ------------------------------------------------ */
/* dS/dp = 2Vp operation */
bspline_score->total_grad[3*knots[j]+0] += 2 * reg_parms->lambda * X[j];
bspline_score->total_grad[3*knots[j]+1] += 2 * reg_parms->lambda * Y[j];
bspline_score->total_grad[3*knots[j]+2] += 2 * reg_parms->lambda * Z[j];
bspline_score->total_grad[3*knots[j]+0] += 2 * reg_parms->curvature_penalty * X[j];
bspline_score->total_grad[3*knots[j]+1] += 2 * reg_parms->curvature_penalty * Y[j];
bspline_score->total_grad[3*knots[j]+2] += 2 * reg_parms->curvature_penalty * Z[j];
}
bspline_score->rmetric += S;
......
......@@ -431,15 +431,13 @@ Bspline_regularize::compute_score_semi_analytic (
grad_score = 0;
num_vox = bxf->roi_dim[0] * bxf->roi_dim[1] * bxf->roi_dim[2];
grad_coeff = parms->lambda / num_vox;
grad_coeff = parms->curvature_penalty / num_vox;
Plm_timer* timer = new Plm_timer;
timer->start ();
bscore->rmetric = 0.;
//printf("---- YOUNG MODULUS %f\n", parms->lambda);
for (rk = 0, fk = bxf->roi_offset[2]; rk < bxf->roi_dim[2]; rk++, fk++) {
p[2] = rk / bxf->vox_per_rgn[2];
q[2] = rk % bxf->vox_per_rgn[2];
......@@ -497,12 +495,9 @@ Bspline_regularize::compute_score_semi_analytic (
}
bscore->time_rmetric = timer->report ();
//raw_score = grad_score / num_vox;
grad_score *= (parms->lambda / num_vox);
//printf (" GRAD_COST %.4f RAW_GRAD %.4f [%.3f secs]\n", grad_score, raw_score, interval);
grad_score *= (parms->curvature_penalty / num_vox);
bscore->rmetric += grad_score;
}
//printf ("SCORE=%.4f\n", bscore->score);
delete timer;
}
......
......@@ -268,33 +268,40 @@ Bspline_stage::initialize ()
}
/* Regularization */
parms->reg_parms->lambda = stage->regularization_lambda;
switch (stage->regularization_type) {
parms->regularization_parms = &stage->regularization_parms;
switch (parms->regularization_parms->regularization_type) {
case REGULARIZATION_NONE:
parms->reg_parms->lambda = 0.0f;
parms->regularization_parms->implementation = '\0';
break;
case REGULARIZATION_BSPLINE_ANALYTIC:
if (stage->threading_type == THREADING_CPU_SINGLE) {
parms->reg_parms->implementation = 'b';
parms->regularization_parms->implementation = 'b';
} else {
parms->reg_parms->implementation = 'c';
parms->regularization_parms->implementation = 'c';
}
break;
case REGULARIZATION_BSPLINE_SEMI_ANALYTIC:
parms->reg_parms->implementation = 'd';
parms->regularization_parms->implementation = 'd';
break;
case REGULARIZATION_BSPLINE_NUMERIC:
parms->reg_parms->implementation = 'a';
parms->regularization_parms->implementation = 'a';
break;
default:
print_and_exit ("Undefined regularization type in gpuit_bspline\n");
}
if (stage->regularization_lambda != 0) {
parms->reg_parms->lambda = stage->regularization_lambda;
if (parms->regularization_parms->total_displacement_penalty == 0
&& parms->regularization_parms->diffusion_penalty == 0
&& parms->regularization_parms->curvature_penalty == 0
&& parms->regularization_parms->linear_elastic_penalty == 0
&& parms->regularization_parms->third_order_penalty == 0)
{
parms->regularization_parms->implementation = '\0';
}
if (parms->regularization_parms->implementation != '\0') {
logfile_printf ("Regularization: flavor = %c lambda = %f\n",
parms->regularization_parms->implementation,
parms->regularization_parms->curvature_penalty);
}
logfile_printf ("Regularization: flavor = %c lambda = %f\n",
parms->reg_parms->implementation,
parms->reg_parms->lambda);
/* Mutual information histograms */
parms->mi_hist_type = stage->mi_hist_type;
......
......@@ -81,7 +81,7 @@ Bspline_state::initialize (
Bspline_xform *bxf,
Bspline_parms *parms)
{
Regularization_parms* reg_parms = parms->reg_parms;
const Regularization_parms* rparms = parms->regularization_parms;
Bspline_regularize* rst = &this->rst;
Bspline_landmarks* blm = parms->blm;
......@@ -96,9 +96,9 @@ Bspline_state::initialize (
this->ssd.set_num_coeff (bxf->num_coeff);
if (reg_parms->lambda > 0.0f) {
if (rparms->curvature_penalty > 0.0f) {
rst->fixed_stiffness = parms->fixed_stiffness;
rst->initialize (reg_parms, bxf);
rst->initialize (rparms, bxf);
}
/* Initialize MI histograms */
......
......@@ -485,26 +485,30 @@ Registration_parms::set_key_value (
else if (key == "regularization")
{
if (!section_stage) goto key_only_allowed_in_section_stage;
Regularization_type& rtype
= stage->regularization_parms.regularization_type;
if (val == "none") {
stage->regularization_type = REGULARIZATION_NONE;
rtype = REGULARIZATION_NONE;
}
else if (val == "analytic") {
stage->regularization_type = REGULARIZATION_BSPLINE_ANALYTIC;
rtype = REGULARIZATION_BSPLINE_ANALYTIC;
}
else if (val == "semi_analytic") {
stage->regularization_type = REGULARIZATION_BSPLINE_SEMI_ANALYTIC;
rtype = REGULARIZATION_BSPLINE_SEMI_ANALYTIC;
}
else if (val == "numeric") {
stage->regularization_type = REGULARIZATION_BSPLINE_NUMERIC;
rtype = REGULARIZATION_BSPLINE_NUMERIC;
}
else {
goto error_exit;
}
}
else if (key == "regularization_lambda"
else if (key == "curvature_penalty"
|| key == "regularization_lambda"
|| key == "young_modulus") {
if (!section_stage) goto key_only_allowed_in_section_stage;
if (sscanf (val.c_str(), "%f", &stage->regularization_lambda) != 1) {
if (sscanf (val.c_str(), "%f",
&stage->regularization_parms.curvature_penalty) != 1) {
goto error_exit;
}
}
......
/* -----------------------------------------------------------------------
See COPYRIGHT.TXT and LICENSE.TXT for copyright and license information
----------------------------------------------------------------------- */
#ifndef _regularization_parms_h_
#define _regularization_parms_h_
#include "plmregister_config.h"
enum Regularization_type {
REGULARIZATION_NONE,
REGULARIZATION_BSPLINE_ANALYTIC,
REGULARIZATION_BSPLINE_SEMI_ANALYTIC,
REGULARIZATION_BSPLINE_NUMERIC
};
class Regularization_parms
{
public:
Regularization_type regularization_type;
mutable char implementation;
float total_displacement_penalty;
float diffusion_penalty;
float curvature_penalty;
float linear_elastic_penalty;
float third_order_penalty;
public:
Regularization_parms () {
this->regularization_type = REGULARIZATION_BSPLINE_ANALYTIC;
this->implementation = '\0';
this->total_displacement_penalty = 0.f;
this->diffusion_penalty = 0.f;
this->curvature_penalty = 0.f;
this->linear_elastic_penalty = 0.f;
this->third_order_penalty = 0.f;
}
};
#endif
......@@ -49,10 +49,7 @@ Stage_parms::Stage_parms ()
alg_flavor = 0;
threading_type = THREADING_CPU_OPENMP;
gpuid = 0;
/* Similarity metric */
regularization_type = REGULARIZATION_BSPLINE_ANALYTIC;
demons_gradient_type = SYMMETRIC;
regularization_lambda = 0.0f;
/* Regularization - default constructor is ok */
/* Image resample */
resample_type = RESAMPLE_AUTO;
resample_rate_fixed[0] = 4;
......@@ -81,12 +78,12 @@ Stage_parms::Stage_parms ()
rotation_scale_factor = 1;
scaling_scale_factor = 10;
/*OnePlusOne evolutionary optimizer*/
opo_initial_search_rad=1.01;
opo_epsilon=1e-7;
opo_initial_search_rad = 1.01;
opo_epsilon = 1e-7;
/*frpr optimizer*/
frpr_step_tol=0.000001;
frpr_step_length=5.0;
frpr_max_line_its=100;
frpr_step_tol = 0.000001;
frpr_step_length = 5.0;
frpr_max_line_its = 100;
/* Quaternion optimizer */
learn_rate = 0.01 ;
/* Mattes mutual information */
......@@ -98,17 +95,18 @@ Stage_parms::Stage_parms ()
/* MI threshold values */
/*Setting values to zero by default. In this case minVal and
maxVal will be calculated from image*/
mi_fixed_image_minVal=0;
mi_fixed_image_maxVal=0;
mi_moving_image_minVal=0;
mi_moving_image_maxVal=0;
mi_fixed_image_minVal = 0;
mi_fixed_image_maxVal = 0;
mi_moving_image_minVal = 0;
mi_moving_image_maxVal = 0;
/* ITK & GPUIT demons */
demons_std = 1.0;
demons_std_update_field = 1.0;
demons_smooth_deformation_field =true;
demons_smooth_update_field=false;
demons_step_length = 2.0;
num_approx_terms_log_demons=2;
demons_smooth_update_field = false;
demons_smooth_deformation_field = true;
num_approx_terms_log_demons = 2;
demons_gradient_type = SYMMETRIC;
/* GPUIT demons */
demons_acceleration = 1.0;
demons_homogenization = 1.0;
......@@ -122,9 +120,9 @@ Stage_parms::Stage_parms ()
grid_spac[1] = 20.;
grid_spac[2] = 20.;
histoeq = false; // by default, don't do it
thresh_mean_intensity=false;
num_matching_points=500;
num_hist_levels=1000;
thresh_mean_intensity = false;
num_matching_points = 500;
num_hist_levels = 1000;
/* Native grid search */
gridsearch_strategy = GRIDSEARCH_STRATEGY_AUTO;
gridsearch_min_overlap[0] = 0.5;
......@@ -165,9 +163,8 @@ Stage_parms::Stage_parms (const Stage_parms& s)
alg_flavor = s.alg_flavor;
threading_type = s.threading_type;
gpuid = s.gpuid;
/* Similarity metric */
regularization_type = s.regularization_type;
regularization_lambda = s.regularization_lambda;
/* Regularization */
regularization_parms = s.regularization_parms;
/* Image resample */
resample_type = s.resample_type;
resample_rate_fixed[0] = s.resample_rate_fixed[0];
......
......@@ -15,6 +15,7 @@
#include "plm_image_type.h"
#include "plm_return_code.h"
#include "process_parms.h"
#include "regularization_parms.h"
#include "similarity_metric_type.h"
#include "threading.h"
......@@ -71,13 +72,6 @@ enum Resample_type {
RESAMPLE_DIM /* res_dim */
};
enum Regularization_type {
REGULARIZATION_NONE,
REGULARIZATION_BSPLINE_ANALYTIC,
REGULARIZATION_BSPLINE_SEMI_ANALYTIC,
REGULARIZATION_BSPLINE_NUMERIC
};
enum Demons_gradient_type {
SYMMETRIC,
FIXED_IMAGE,
......@@ -125,8 +119,7 @@ public:
Threading threading_type;
int gpuid; /* Sets GPU to use for multi-gpu machines */
/* Regularization */
Regularization_type regularization_type;
float regularization_lambda;
Regularization_parms regularization_parms;
/* Image resampling */
/* The units of fixed_resampling_rate are: voxels for res_vox,
mm for res_mm, pct for res_pct, voxels for res_dim */
......@@ -175,7 +168,8 @@ public:
float demons_std;
float demons_std_update_field;
float demons_step_length;
bool demons_smooth_update_field, demons_smooth_deformation_field;
bool demons_smooth_update_field;
bool demons_smooth_deformation_field;
unsigned int num_approx_terms_log_demons;
bool histoeq; // histogram matching flag
bool thresh_mean_intensity;
......
......@@ -22,10 +22,6 @@ set (BRAGG_CURVE_SRC
set (SOBP_SRC
sobp_main.cxx
)
set (BSPLINE_SRC
bspline_main.cxx
bspline_opts.cxx
)
set (CHECK_GRAD_SRC
check_grad.cxx
)
......@@ -263,8 +259,6 @@ if (QT4_FOUND)
"${PLASTIMATCH_LDFLAGS}" ${INSTALL_IF_NOT_DEBIAN})
endif()
plm_add_executable_v3 (bspline "${BSPLINE_SRC}"
"" "${PLASTIMATCH_LIBS}" "${PLASTIMATCH_LDFLAGS}" ${INSTALL_NEVER})
plm_add_executable_v3 (check_grad "${CHECK_GRAD_SRC}"
"" "${PLASTIMATCH_LIBS}" "${PLASTIMATCH_LDFLAGS}" ${INSTALL_NEVER})
plm_add_executable_v3 (demons "${DEMONS_SRC}"
......
......@@ -58,7 +58,7 @@ bspline_opts_parse_args (Bspline_options* options, int argc, char* argv[])
{
int i, rc;
Bspline_parms* parms = &options->parms;
Regularization_parms* reg_parms = parms->reg_parms;
Regularization_parms* reg_parms = parms->regularization_parms;
Bspline_landmarks* blm = parms->blm;
for (i = 1; i < argc; i++) {
......@@ -145,7 +145,7 @@ bspline_opts_parse_args (Bspline_options* options, int argc, char* argv[])
exit(1);
}
i++;
rc = sscanf (argv[i], "%g", &reg_parms->lambda);
rc = sscanf (argv[i], "%g", &reg_parms->curvature_penalty);
if (rc != 1) {
print_usage ();
}
......@@ -320,7 +320,7 @@ bspline_opts_parse_args (Bspline_options* options, int argc, char* argv[])
exit(1);
}
i++;
rc = sscanf (argv[i], "%g", &reg_parms->lambda);
rc = sscanf (argv[i], "%g", &reg_parms->curvature_penalty);
if (rc != 1) {
print_usage ();
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment