選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

solver.cpp 11 KiB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. //
  2. // LolRemez - Remez algorithm implementation
  3. //
  4. // Copyright © 2005—2015 Sam Hocevar <sam@hocevar.net>
  5. //
  6. // This program is free software. It comes without any warranty, to
  7. // the extent permitted by applicable law. You can redistribute it
  8. // and/or modify it under the terms of the Do What the Fuck You Want
  9. // to Public License, Version 2, as published by the WTFPL Task Force.
  10. // See http://www.wtfpl.net/ for more details.
  11. //
  12. #if HAVE_CONFIG_H
  13. # include "config.h"
  14. #endif
  15. #include <functional>
  16. #include <lol/engine.h>
  17. #include <lol/math/real.h>
  18. #include <lol/math/polynomial.h>
  19. #include "matrix.h"
  20. #include "solver.h"
  21. using lol::real;
  22. remez_solver::remez_solver(int order, int decimals)
  23. : m_order(order),
  24. m_decimals(decimals),
  25. m_has_weight(false)
  26. {
  27. /* Spawn 4 worker threads */
  28. for (int i = 0; i < 4; ++i)
  29. {
  30. auto th = new thread(std::bind(&remez_solver::worker_thread, this));
  31. m_workers.push(th);
  32. }
  33. }
  34. remez_solver::~remez_solver()
  35. {
  36. /* Signal worker threads to quit, wait for worker threads to answer,
  37. * and kill worker threads. */
  38. for (auto worker : m_workers)
  39. UNUSED(worker), m_questions.push(-1);
  40. for (auto worker : m_workers)
  41. UNUSED(worker), m_answers.pop();
  42. for (auto worker : m_workers)
  43. delete worker;
  44. }
  45. void remez_solver::run(real a, real b, char const *func, char const *weight)
  46. {
  47. m_func.parse(func);
  48. if (weight)
  49. {
  50. m_weight.parse(weight);
  51. m_has_weight = true;
  52. }
  53. m_k1 = (b + a) / 2;
  54. m_k2 = (b - a) / 2;
  55. m_epsilon = pow((real)10, (real)-(m_decimals + 2));
  56. remez_init();
  57. print_poly();
  58. for (int n = 0; ; n++)
  59. {
  60. real old_error = m_error;
  61. find_extrema();
  62. remez_step();
  63. if (m_error >= (real)0
  64. && fabs(m_error - old_error) < m_error * m_epsilon)
  65. break;
  66. print_poly();
  67. find_zeroes();
  68. }
  69. print_poly();
  70. }
  71. /*
  72. * This is basically the first Remez step: we solve a system of
  73. * order N+1 and get a good initial polynomial estimate.
  74. */
  75. void remez_solver::remez_init()
  76. {
  77. /* m_order + 1 zeroes of the error function */
  78. m_zeroes.resize(m_order + 1);
  79. /* m_order + 1 zeroes to find */
  80. m_zeroes_state.resize(m_order + 1);
  81. /* m_order + 2 control points */
  82. m_control.resize(m_order + 2);
  83. /* m_order extrema to find */
  84. m_extrema_state.resize(m_order);
  85. /* Initial estimates for the x_i where the error will be zero and
  86. * precompute f(x_i). */
  87. array<real> fxn;
  88. for (int i = 0; i < m_order + 1; i++)
  89. {
  90. m_zeroes[i] = (real)(2 * i - m_order) / (real)(m_order + 1);
  91. fxn.push(eval_func(m_zeroes[i]));
  92. }
  93. /* We build a matrix of Chebyshev evaluations: row i contains the
  94. * evaluations of x_i for polynomial order n = 0, 1, ... */
  95. linear_system<real> system(m_order + 1);
  96. for (int n = 0; n < m_order + 1; n++)
  97. {
  98. auto p = polynomial<real>::chebyshev(n);
  99. for (int i = 0; i < m_order + 1; i++)
  100. system[i][n] = p.eval(m_zeroes[i]);
  101. }
  102. /* Solve the system */
  103. system = system.inverse();
  104. /* Compute new Chebyshev estimate */
  105. m_estimate = polynomial<real>();
  106. for (int n = 0; n < m_order + 1; n++)
  107. {
  108. real weight = 0;
  109. for (int i = 0; i < m_order + 1; i++)
  110. weight += system[n][i] * fxn[i];
  111. m_estimate += weight * polynomial<real>::chebyshev(n);
  112. }
  113. }
  114. /*
  115. * Every subsequent iteration of the Remez algorithm: we solve a system
  116. * of order N+2 to both refine the estimate and compute the error.
  117. */
  118. void remez_solver::remez_step()
  119. {
  120. Timer t;
  121. /* Pick up x_i where error will be 0 and compute f(x_i) */
  122. array<real> fxn;
  123. for (int i = 0; i < m_order + 2; i++)
  124. fxn.push(eval_func(m_control[i]));
  125. /* We build a matrix of Chebyshev evaluations: row i contains the
  126. * evaluations of x_i for polynomial order n = 0, 1, ... */
  127. linear_system<real> system(m_order + 2);
  128. for (int n = 0; n < m_order + 1; n++)
  129. {
  130. auto p = polynomial<real>::chebyshev(n);
  131. for (int i = 0; i < m_order + 2; i++)
  132. system[i][n] = p.eval(m_control[i]);
  133. }
  134. /* The last line of the system is the oscillating error */
  135. for (int i = 0; i < m_order + 2; i++)
  136. {
  137. real error = fabs(eval_weight(m_control[i]));
  138. system[i][m_order + 1] = (i & 1) ? error : -error;
  139. }
  140. /* Solve the system */
  141. system = system.inverse();
  142. /* Compute new polynomial estimate */
  143. m_estimate = polynomial<real>();
  144. for (int n = 0; n < m_order + 1; n++)
  145. {
  146. real weight = 0;
  147. for (int i = 0; i < m_order + 2; i++)
  148. weight += system[n][i] * fxn[i];
  149. m_estimate += weight * polynomial<real>::chebyshev(n);
  150. }
  151. /* Compute the error (FIXME: unused?) */
  152. real error = 0;
  153. for (int i = 0; i < m_order + 2; i++)
  154. error += system[m_order + 1][i] * fxn[i];
  155. using std::printf;
  156. printf(" -:- timing for inversion: %f ms\n", t.Get() * 1000.f);
  157. }
  158. /*
  159. * Find m_order + 1 zeroes of the error function. No need to compute the
  160. * relative error: its zeroes are at the same place as the absolute error!
  161. *
  162. * The algorithm used here is naïve regula falsi. It still performs a lot
  163. * better than the midpoint algorithm.
  164. */
  165. void remez_solver::find_zeroes()
  166. {
  167. Timer t;
  168. static real const limit = ldexp((real)1, -500);
  169. static real const zero = (real)0;
  170. /* Initialise an [a,b] bracket for each zero we try to find */
  171. for (int i = 0; i < m_order + 1; i++)
  172. {
  173. point &a = m_zeroes_state[i].m1;
  174. point &b = m_zeroes_state[i].m2;
  175. a.x = m_control[i];
  176. a.err = eval_estimate(a.x) - eval_func(a.x);
  177. b.x = m_control[i + 1];
  178. b.err = eval_estimate(b.x) - eval_func(b.x);
  179. m_questions.push(i);
  180. }
  181. /* Watch all brackets for updates from worker threads */
  182. for (int finished = 0; finished < m_order + 1; )
  183. {
  184. int i = m_answers.pop();
  185. point &a = m_zeroes_state[i].m1;
  186. point &b = m_zeroes_state[i].m2;
  187. point &c = m_zeroes_state[i].m3;
  188. if (c.err == zero || fabs(a.x - b.x) <= limit)
  189. {
  190. m_zeroes[i] = c.x;
  191. ++finished;
  192. continue;
  193. }
  194. m_questions.push(i);
  195. }
  196. using std::printf;
  197. printf(" -:- timing for zeroes: %f ms\n", t.Get() * 1000.f);
  198. }
  199. /*
  200. * Find m_order extrema of the error function. We maximise the relative
  201. * error, since its extrema are at slightly different locations than the
  202. * absolute error’s.
  203. *
  204. * The algorithm used here is successive parabolic interpolation. FIXME: we
  205. * could use Brent’s method instead, which combines parabolic interpolation
  206. * and golden ratio search and has superlinear convergence.
  207. */
  208. void remez_solver::find_extrema()
  209. {
  210. Timer t;
  211. m_control[0] = -1;
  212. m_control[m_order + 1] = 1;
  213. m_error = 0;
  214. /* Initialise an [a,b,c] bracket for each extremum we try to find */
  215. for (int i = 0; i < m_order; i++)
  216. {
  217. point &a = m_extrema_state[i].m1;
  218. point &b = m_extrema_state[i].m2;
  219. point &c = m_extrema_state[i].m3;
  220. a.x = m_zeroes[i];
  221. b.x = m_zeroes[i + 1];
  222. c.x = a.x + (b.x - a.x) * real(rand(0.4f, 0.6f));
  223. a.err = eval_error(a.x);
  224. b.err = eval_error(b.x);
  225. c.err = eval_error(c.x);
  226. m_questions.push(i + 1000);
  227. }
  228. /* Watch all brackets for updates from worker threads */
  229. for (int finished = 0; finished < m_order; )
  230. {
  231. int i = m_answers.pop();
  232. point &a = m_extrema_state[i - 1000].m1;
  233. point &b = m_extrema_state[i - 1000].m2;
  234. point &c = m_extrema_state[i - 1000].m3;
  235. if (b.x - a.x <= m_epsilon)
  236. {
  237. m_control[i - 1000 + 1] = c.x;
  238. if (c.err > m_error)
  239. m_error = c.err;
  240. ++finished;
  241. continue;
  242. }
  243. m_questions.push(i);
  244. }
  245. using std::printf;
  246. printf(" -:- timing for extrema: %f ms\n", t.Get() * 1000.f);
  247. printf(" -:- error: ");
  248. m_error.print(m_decimals);
  249. printf("\n");
  250. }
  251. void remez_solver::print_poly()
  252. {
  253. /* Transform our polynomial in the [-1..1] range into a polynomial
  254. * in the [a..b] range by composing it with q:
  255. * q(x) = 2x / (b-a) - (b+a) / (b-a) */
  256. polynomial<real> q ({ -m_k1 / m_k2, real(1) / m_k2 });
  257. polynomial<real> r = m_estimate.eval(q);
  258. using std::printf;
  259. printf("\n");
  260. for (int j = 0; j < m_order + 1; j++)
  261. {
  262. if (j)
  263. printf(" + x**%i * ", j);
  264. r[j].print(m_decimals);
  265. }
  266. printf("\n\n");
  267. }
  268. real remez_solver::eval_estimate(real const &x)
  269. {
  270. return m_estimate.eval(x);
  271. }
  272. real remez_solver::eval_func(real const &x)
  273. {
  274. return m_func.eval(x * m_k2 + m_k1);
  275. }
  276. real remez_solver::eval_weight(real const &x)
  277. {
  278. return m_has_weight ? m_weight.eval(x * m_k2 + m_k1) : real(1);
  279. }
  280. real remez_solver::eval_error(real const &x)
  281. {
  282. return fabs((eval_estimate(x) - eval_func(x)) / eval_weight(x));
  283. }
  284. void remez_solver::worker_thread()
  285. {
  286. static real const zero = (real)0;
  287. for (;;)
  288. {
  289. int i = m_questions.pop();
  290. if (i < 0)
  291. {
  292. m_answers.push(i);
  293. break;
  294. }
  295. else if (i < 1000)
  296. {
  297. point &a = m_zeroes_state[i].m1;
  298. point &b = m_zeroes_state[i].m2;
  299. point &c = m_zeroes_state[i].m3;
  300. real s = abs(b.err) / (abs(a.err) + abs(b.err));
  301. real newc = b.x + s * (a.x - b.x);
  302. /* If the third point didn't change since last iteration,
  303. * we may be at an inflection point. Use the midpoint to get
  304. * out of this situation. */
  305. c.x = newc != c.x ? newc : (a.x + b.x) / 2;
  306. c.err = eval_estimate(c.x) - eval_func(c.x);
  307. if ((a.err < zero && c.err < zero)
  308. || (a.err > zero && c.err > zero))
  309. a = c;
  310. else
  311. b = c;
  312. m_answers.push(i);
  313. }
  314. else if (i < 2000)
  315. {
  316. point &a = m_extrema_state[i - 1000].m1;
  317. point &b = m_extrema_state[i - 1000].m2;
  318. point &c = m_extrema_state[i - 1000].m3;
  319. point d;
  320. real d1 = c.x - a.x, d2 = c.x - b.x;
  321. real k1 = d1 * (c.err - b.err);
  322. real k2 = d2 * (c.err - a.err);
  323. d.x = c.x - (d1 * k1 - d2 * k2) / (k1 - k2) / 2;
  324. /* If parabolic interpolation failed, pick a number
  325. * inbetween. */
  326. if (d.x <= a.x || d.x >= b.x)
  327. d.x = (a.x + b.x) / 2;
  328. d.err = eval_error(d.x);
  329. /* Update bracketing depending on the new point. */
  330. if (d.err < c.err)
  331. {
  332. (d.x > c.x ? b : a) = d;
  333. }
  334. else
  335. {
  336. (d.x > c.x ? a : b) = c;
  337. c = d;
  338. }
  339. m_answers.push(i);
  340. }
  341. }
  342. }