|  | @@ -29,6 +29,7 @@
 | 
											
												
													
														|  |  // Author: sameeragarwal@google.com (Sameer Agarwal)
 |  |  // Author: sameeragarwal@google.com (Sameer Agarwal)
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  #include "gtest/gtest.h"
 |  |  #include "gtest/gtest.h"
 | 
											
												
													
														|  | 
 |  | +#include "ceres/autodiff_cost_function.h"
 | 
											
												
													
														|  |  #include "ceres/linear_solver.h"
 |  |  #include "ceres/linear_solver.h"
 | 
											
												
													
														|  |  #include "ceres/parameter_block.h"
 |  |  #include "ceres/parameter_block.h"
 | 
											
												
													
														|  |  #include "ceres/problem_impl.h"
 |  |  #include "ceres/problem_impl.h"
 | 
											
										
											
												
													
														|  | @@ -560,5 +561,69 @@ TEST(SolverImpl, CreateLinearSolverNormalOperation) {
 | 
											
												
													
														|  |    EXPECT_TRUE(solver.get() != NULL);
 |  |    EXPECT_TRUE(solver.get() != NULL);
 | 
											
												
													
														|  |  }
 |  |  }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | 
 |  | +struct QuadraticCostFunction {
 | 
											
												
													
														|  | 
 |  | +  template <typename T> bool operator()(const T* const x,
 | 
											
												
													
														|  | 
 |  | +                                        T* residual) const {
 | 
											
												
													
														|  | 
 |  | +    residual[0] = T(5.0) - *x;
 | 
											
												
													
														|  | 
 |  | +    return true;
 | 
											
												
													
														|  | 
 |  | +  }
 | 
											
												
													
														|  | 
 |  | +};
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +struct RememberingCallback : public IterationCallback {
 | 
											
												
													
														|  | 
 |  | +  RememberingCallback(double *x) : calls(0), x(x) {}
 | 
											
												
													
														|  | 
 |  | +  virtual ~RememberingCallback() {}
 | 
											
												
													
														|  | 
 |  | +  virtual CallbackReturnType operator()(const IterationSummary& summary) {
 | 
											
												
													
														|  | 
 |  | +    x_values.push_back(*x);
 | 
											
												
													
														|  | 
 |  | +    return SOLVER_CONTINUE;
 | 
											
												
													
														|  | 
 |  | +  }
 | 
											
												
													
														|  | 
 |  | +  int calls;
 | 
											
												
													
														|  | 
 |  | +  double *x;
 | 
											
												
													
														|  | 
 |  | +  vector<double> x_values;
 | 
											
												
													
														|  | 
 |  | +};
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +TEST(SolverImpl, UpdateStateEveryIterationOption) {
 | 
											
												
													
														|  | 
 |  | +  double x = 50.0;
 | 
											
												
													
														|  | 
 |  | +  const double original_x = x;
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +  scoped_ptr<CostFunction> cost_function(
 | 
											
												
													
														|  | 
 |  | +      new AutoDiffCostFunction<QuadraticCostFunction, 1, 1>(
 | 
											
												
													
														|  | 
 |  | +          new QuadraticCostFunction));
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +  Problem::Options problem_options;
 | 
											
												
													
														|  | 
 |  | +  problem_options.cost_function_ownership = DO_NOT_TAKE_OWNERSHIP;
 | 
											
												
													
														|  | 
 |  | +  Problem problem(problem_options);
 | 
											
												
													
														|  | 
 |  | +  problem.AddResidualBlock(cost_function.get(), NULL, &x);
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +  Solver::Options options;
 | 
											
												
													
														|  | 
 |  | +  options.linear_solver_type = DENSE_QR;
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +  RememberingCallback callback(&x);
 | 
											
												
													
														|  | 
 |  | +  options.callbacks.push_back(&callback);
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +  Solver::Summary summary;
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +  int num_iterations;
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +  // First try: no updating.
 | 
											
												
													
														|  | 
 |  | +  SolverImpl::Solve(options, &problem, &summary);
 | 
											
												
													
														|  | 
 |  | +  num_iterations = summary.num_successful_steps +
 | 
											
												
													
														|  | 
 |  | +                   summary.num_unsuccessful_steps;
 | 
											
												
													
														|  | 
 |  | +  EXPECT_GT(num_iterations, 1);
 | 
											
												
													
														|  | 
 |  | +  for (int i = 0; i < callback.x_values.size(); ++i) {
 | 
											
												
													
														|  | 
 |  | +    EXPECT_EQ(50.0, callback.x_values[i]);
 | 
											
												
													
														|  | 
 |  | +  }
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +  // Second try: with updating
 | 
											
												
													
														|  | 
 |  | +  x = 50.0;
 | 
											
												
													
														|  | 
 |  | +  options.update_state_every_iteration = true;
 | 
											
												
													
														|  | 
 |  | +  callback.x_values.clear();
 | 
											
												
													
														|  | 
 |  | +  SolverImpl::Solve(options, &problem, &summary);
 | 
											
												
													
														|  | 
 |  | +  num_iterations = summary.num_successful_steps +
 | 
											
												
													
														|  | 
 |  | +                   summary.num_unsuccessful_steps;
 | 
											
												
													
														|  | 
 |  | +  EXPECT_GT(num_iterations, 1);
 | 
											
												
													
														|  | 
 |  | +  EXPECT_EQ(original_x, callback.x_values[0]);
 | 
											
												
													
														|  | 
 |  | +  EXPECT_NE(original_x, callback.x_values[1]);
 | 
											
												
													
														|  | 
 |  | +}
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  |  }  // namespace internal
 |  |  }  // namespace internal
 | 
											
												
													
														|  |  }  // namespace ceres
 |  |  }  // namespace ceres
 |