source: git/ntl/src/ZZX.c @ 311902

spielwiese
Last change on this file since 311902 was 311902, checked in by Hans Schönemann <hannes@…>, 20 years ago
*hannes: 5.3.2-C++-fix git-svn-id: file:///usr/local/Singular/svn/trunk@7474 2c84dea3-7e68-4137-9b89-c4e89433aadc
  • Property mode set to 100644
File size: 15.0 KB
Line 
1
2#include <NTL/ZZX.h>
3
4#include <NTL/new.h>
5
6NTL_START_IMPL
7
8
9
10const ZZX& ZZX::zero()
11{
12   static ZZX z;
13   return z;
14}
15
16
17
18void conv(ZZ_pX& x, const ZZX& a)
19{
20   conv(x.rep, a.rep);
21   x.normalize();
22}
23
24void conv(ZZX& x, const ZZ_pX& a)
25{
26   conv(x.rep, a.rep);
27   x.normalize();
28}
29
30
31void ZZX::normalize()
32{
33   long n;
34   const ZZ* p;
35
36   n = rep.length();
37   if (n == 0) return;
38   p = rep.elts() + n;
39   while (n > 0 && IsZero(*--p)) {
40      n--;
41   }
42   rep.SetLength(n);
43}
44
45
46long IsZero(const ZZX& a)
47{
48   return a.rep.length() == 0;
49}
50
51
52long IsOne(const ZZX& a)
53{
54    return a.rep.length() == 1 && IsOne(a.rep[0]);
55}
56
57long operator==(const ZZX& a, const ZZX& b)
58{
59   long i, n;
60   const ZZ *ap, *bp;
61
62   n = a.rep.length();
63   if (n != b.rep.length()) return 0;
64
65   ap = a.rep.elts();
66   bp = b.rep.elts();
67
68   for (i = 0; i < n; i++)
69      if (ap[i] != bp[i]) return 0;
70
71   return 1;
72}
73
74
75long operator==(const ZZX& a, long b)
76{
77   if (b == 0)
78      return IsZero(a);
79
80   if (deg(a) != 0)
81      return 0;
82
83   return a.rep[0] == b;
84}
85
86long operator==(const ZZX& a, const ZZ& b)
87{
88   if (IsZero(b))
89      return IsZero(a);
90
91   if (deg(a) != 0)
92      return 0;
93
94   return a.rep[0] == b;
95}
96
97
98void GetCoeff(ZZ& x, const ZZX& a, long i)
99{
100   if (i < 0 || i > deg(a))
101      clear(x);
102   else
103      x = a.rep[i];
104}
105
106void SetCoeff(ZZX& x, long i, const ZZ& a)
107{
108   long j, m;
109
110   if (i < 0) 
111      Error("SetCoeff: negative index");
112
113   if (NTL_OVERFLOW(i, 1, 0))
114      Error("overflow in SetCoeff");
115
116   m = deg(x);
117
118   if (i > m) {
119      /* careful: a may alias a coefficient of x */
120
121      long alloc = x.rep.allocated();
122
123      if (alloc > 0 && i >= alloc) {
124         ZZ aa = a;
125         x.rep.SetLength(i+1);
126         x.rep[i] = aa;
127      }
128      else {
129         x.rep.SetLength(i+1);
130         x.rep[i] = a;
131      }
132         
133      for (j = m+1; j < i; j++)
134         clear(x.rep[j]);
135   }
136   else
137      x.rep[i] = a;
138
139   x.normalize();
140}
141
142
143void SetCoeff(ZZX& x, long i)
144{
145   long j, m;
146
147   if (i < 0) 
148      Error("coefficient index out of range");
149
150   if (NTL_OVERFLOW(i, 1, 0))
151      Error("overflow in SetCoeff");
152
153   m = deg(x);
154
155   if (i > m) {
156      x.rep.SetLength(i+1);
157      for (j = m+1; j < i; j++)
158         clear(x.rep[j]);
159   }
160   set(x.rep[i]);
161   x.normalize();
162}
163
164
165void SetX(ZZX& x)
166{
167   clear(x);
168   SetCoeff(x, 1);
169}
170
171
172long IsX(const ZZX& a)
173{
174   return deg(a) == 1 && IsOne(LeadCoeff(a)) && IsZero(ConstTerm(a));
175}
176     
177     
178
179const ZZ& coeff(const ZZX& a, long i)
180{
181   if (i < 0 || i > deg(a))
182      return ZZ::zero();
183   else
184      return a.rep[i];
185}
186
187
188const ZZ& LeadCoeff(const ZZX& a)
189{
190   if (IsZero(a))
191      return ZZ::zero();
192   else
193      return a.rep[deg(a)];
194}
195
196const ZZ& ConstTerm(const ZZX& a)
197{
198   if (IsZero(a))
199      return ZZ::zero();
200   else
201      return a.rep[0];
202}
203
204
205
206void conv(ZZX& x, const ZZ& a)
207{
208   if (IsZero(a))
209      x.rep.SetLength(0);
210   else {
211      x.rep.SetLength(1);
212      x.rep[0] = a;
213   }
214}
215
216
217void conv(ZZX& x, long a)
218{
219   if (a == 0) 
220      x.rep.SetLength(0);
221   else {
222      x.rep.SetLength(1);
223      conv(x.rep[0], a);
224   }
225}
226
227
228void conv(ZZX& x, const vec_ZZ& a)
229{
230   x.rep = a;
231   x.normalize();
232}
233
234
235void add(ZZX& x, const ZZX& a, const ZZX& b)
236{
237   long da = deg(a);
238   long db = deg(b);
239   long minab = min(da, db);
240   long maxab = max(da, db);
241   x.rep.SetLength(maxab+1);
242
243   long i;
244   const ZZ *ap, *bp; 
245   ZZ* xp;
246
247   for (i = minab+1, ap = a.rep.elts(), bp = b.rep.elts(), xp = x.rep.elts();
248        i; i--, ap++, bp++, xp++)
249      add(*xp, (*ap), (*bp));
250
251   if (da > minab && &x != &a)
252      for (i = da-minab; i; i--, xp++, ap++)
253         *xp = *ap;
254   else if (db > minab && &x != &b)
255      for (i = db-minab; i; i--, xp++, bp++)
256         *xp = *bp;
257   else
258      x.normalize();
259}
260
261void add(ZZX& x, const ZZX& a, const ZZ& b)
262{
263   long n = a.rep.length();
264   if (n == 0) {
265      conv(x, b);
266   }
267   else if (&x == &a) {
268      add(x.rep[0], a.rep[0], b);
269      x.normalize();
270   }
271   else if (x.rep.MaxLength() == 0) {
272      x = a;
273      add(x.rep[0], a.rep[0], b);
274      x.normalize();
275   }
276   else {
277      // ugly...b could alias a coeff of x
278
279      ZZ *xp = x.rep.elts(); 
280      add(xp[0], a.rep[0], b);
281      x.rep.SetLength(n);
282      xp = x.rep.elts();
283      const ZZ *ap = a.rep.elts();
284      long i;
285      for (i = 1; i < n; i++)
286         xp[i] = ap[i];
287      x.normalize();
288   }
289}
290
291
292void add(ZZX& x, const ZZX& a, long b)
293{
294   if (a.rep.length() == 0) {
295      conv(x, b);
296   }
297   else {
298      if (&x != &a) x = a;
299      add(x.rep[0], x.rep[0], b);
300      x.normalize();
301   }
302}
303
304
305void sub(ZZX& x, const ZZX& a, const ZZX& b)
306{
307   long da = deg(a);
308   long db = deg(b);
309   long minab = min(da, db);
310   long maxab = max(da, db);
311   x.rep.SetLength(maxab+1);
312
313   long i;
314   const ZZ *ap, *bp; 
315   ZZ* xp;
316
317   for (i = minab+1, ap = a.rep.elts(), bp = b.rep.elts(), xp = x.rep.elts();
318        i; i--, ap++, bp++, xp++)
319      sub(*xp, (*ap), (*bp));
320
321   if (da > minab && &x != &a)
322      for (i = da-minab; i; i--, xp++, ap++)
323         *xp = *ap;
324   else if (db > minab)
325      for (i = db-minab; i; i--, xp++, bp++)
326         negate(*xp, *bp);
327   else
328      x.normalize();
329
330}
331
332void sub(ZZX& x, const ZZX& a, const ZZ& b)
333{
334   long n = a.rep.length();
335   if (n == 0) {
336      conv(x, b);
337      negate(x, x);
338   }
339   else if (&x == &a) {
340      sub(x.rep[0], a.rep[0], b);
341      x.normalize();
342   }
343   else if (x.rep.MaxLength() == 0) {
344      x = a;
345      sub(x.rep[0], a.rep[0], b);
346      x.normalize();
347   }
348   else {
349      // ugly...b could alias a coeff of x
350
351      ZZ *xp = x.rep.elts();
352      sub(xp[0], a.rep[0], b);
353      x.rep.SetLength(n);
354      xp = x.rep.elts();
355      const ZZ *ap = a.rep.elts();
356      long i;
357      for (i = 1; i < n; i++)
358         xp[i] = ap[i];
359      x.normalize();
360   }
361}
362
363void sub(ZZX& x, const ZZX& a, long b)
364{
365   if (b == 0) {
366      x = a;
367      return;
368   }
369
370   if (a.rep.length() == 0) {
371      x.rep.SetLength(1);
372      conv(x.rep[0], b);
373      negate(x.rep[0], x.rep[0]);
374   }
375   else {
376      if (&x != &a) x = a;
377      sub(x.rep[0], x.rep[0], b);
378   }
379   x.normalize();
380}
381
382void sub(ZZX& x, long a, const ZZX& b)
383{
384   negate(x, b);
385   add(x, x, a);
386}
387
388
389void sub(ZZX& x, const ZZ& b, const ZZX& a)
390{
391   long n = a.rep.length();
392   if (n == 0) {
393      conv(x, b);
394   }
395   else if (x.rep.MaxLength() == 0) {
396      negate(x, a);
397      add(x.rep[0], a.rep[0], b);
398      x.normalize();
399   }
400   else {
401      // ugly...b could alias a coeff of x
402
403      ZZ *xp = x.rep.elts();
404      sub(xp[0], b, a.rep[0]);
405      x.rep.SetLength(n);
406      xp = x.rep.elts();
407      const ZZ *ap = a.rep.elts();
408      long i;
409      for (i = 1; i < n; i++)
410         negate(xp[i], ap[i]);
411      x.normalize();
412   }
413}
414
415
416
417void negate(ZZX& x, const ZZX& a)
418{
419   long n = a.rep.length();
420   x.rep.SetLength(n);
421
422   const ZZ* ap = a.rep.elts();
423   ZZ* xp = x.rep.elts();
424   long i;
425
426   for (i = n; i; i--, ap++, xp++)
427      negate((*xp), (*ap));
428
429}
430
431long MaxBits(const ZZX& f)
432{
433   long i, m;
434   m = 0;
435
436   for (i = 0; i <= deg(f); i++) {
437      m = max(m, NumBits(f.rep[i]));
438   }
439
440   return m;
441}
442
443
444void PlainMul(ZZX& x, const ZZX& a, const ZZX& b)
445{
446   if (&a == &b) {
447      PlainSqr(x, a);
448      return;
449   }
450
451   long da = deg(a);
452   long db = deg(b);
453
454   if (da < 0 || db < 0) {
455      clear(x);
456      return;
457   }
458
459   long d = da+db;
460
461
462
463   const ZZ *ap, *bp;
464   ZZ *xp;
465   
466   ZZX la, lb;
467
468   if (&x == &a) {
469      la = a;
470      ap = la.rep.elts();
471   }
472   else
473      ap = a.rep.elts();
474
475   if (&x == &b) {
476      lb = b;
477      bp = lb.rep.elts();
478   }
479   else
480      bp = b.rep.elts();
481
482   x.rep.SetLength(d+1);
483
484   xp = x.rep.elts();
485
486   long i, j, jmin, jmax;
487   ZZ t, accum;
488
489   for (i = 0; i <= d; i++) {
490      jmin = max(0, i-db);
491      jmax = min(da, i);
492      clear(accum);
493      for (j = jmin; j <= jmax; j++) {
494         mul(t, ap[j], bp[i-j]);
495         add(accum, accum, t);
496      }
497      xp[i] = accum;
498   }
499   x.normalize();
500}
501
502void PlainSqr(ZZX& x, const ZZX& a)
503{
504   long da = deg(a);
505
506   if (da < 0) {
507      clear(x);
508      return;
509   }
510
511   long d = 2*da;
512
513   const ZZ *ap;
514   ZZ *xp;
515
516   ZZX la;
517
518   if (&x == &a) {
519      la = a;
520      ap = la.rep.elts();
521   }
522   else
523      ap = a.rep.elts();
524
525
526   x.rep.SetLength(d+1);
527
528   xp = x.rep.elts();
529
530   long i, j, jmin, jmax;
531   long m, m2;
532   ZZ t, accum;
533
534   for (i = 0; i <= d; i++) {
535      jmin = max(0, i-da);
536      jmax = min(da, i);
537      m = jmax - jmin + 1;
538      m2 = m >> 1;
539      jmax = jmin + m2 - 1;
540      clear(accum);
541      for (j = jmin; j <= jmax; j++) {
542         mul(t, ap[j], ap[i-j]);
543         add(accum, accum, t);
544      }
545      add(accum, accum, accum);
546      if (m & 1) {
547         sqr(t, ap[jmax + 1]);
548         add(accum, accum, t);
549      }
550
551      xp[i] = accum;
552   }
553
554   x.normalize();
555}
556
557
558
559static
560void PlainMul(ZZ *xp, const ZZ *ap, long sa, const ZZ *bp, long sb)
561{
562   if (sa == 0 || sb == 0) return;
563
564   long sx = sa+sb-1;
565
566   long i, j, jmin, jmax;
567   static ZZ t, accum;
568
569   for (i = 0; i < sx; i++) {
570      jmin = max(0, i-sb+1);
571      jmax = min(sa-1, i);
572      clear(accum);
573      for (j = jmin; j <= jmax; j++) {
574         mul(t, ap[j], bp[i-j]);
575         add(accum, accum, t);
576      }
577      xp[i] = accum;
578   }
579}
580
581
582
583static
584void KarFold(ZZ *T, const ZZ *b, long sb, long hsa)
585{
586   long m = sb - hsa;
587   long i;
588
589   for (i = 0; i < m; i++)
590      add(T[i], b[i], b[hsa+i]);
591
592   for (i = m; i < hsa; i++)
593      T[i] = b[i];
594}
595
596static
597void KarSub(ZZ *T, const ZZ *b, long sb)
598{
599   long i;
600
601   for (i = 0; i < sb; i++)
602      sub(T[i], T[i], b[i]);
603}
604
605static
606void KarAdd(ZZ *T, const ZZ *b, long sb)
607{
608   long i;
609
610   for (i = 0; i < sb; i++)
611      add(T[i], T[i], b[i]);
612}
613
614static
615void KarFix(ZZ *c, const ZZ *b, long sb, long hsa)
616{
617   long i;
618
619   for (i = 0; i < hsa; i++)
620      c[i] = b[i];
621
622   for (i = hsa; i < sb; i++)
623      add(c[i], c[i], b[i]);
624}
625
626static void PlainMul1(ZZ *xp, const ZZ *ap, long sa, const ZZ& b)
627{
628   long i;
629
630   for (i = 0; i < sa; i++)
631      mul(xp[i], ap[i], b);
632}
633
634
635
636static
637void KarMul(ZZ *c, const ZZ *a, 
638            long sa, const ZZ *b, long sb, ZZ *stk)
639{
640   if (sa < sb) {
641      { long t = sa; sa = sb; sb = t; }
642      { const ZZ *t = a; a = b; b = t; }
643   }
644
645   if (sb == 1) {
646      if (sa == 1)
647         mul(*c, *a, *b);
648      else
649         PlainMul1(c, a, sa, *b);
650
651      return;
652   }
653
654   if (sb == 2 && sa == 2) {
655      mul(c[0], a[0], b[0]);
656      mul(c[2], a[1], b[1]);
657      add(stk[0], a[0], a[1]);
658      add(stk[1], b[0], b[1]);
659      mul(c[1], stk[0], stk[1]);
660      sub(c[1], c[1], c[0]);
661      sub(c[1], c[1], c[2]);
662
663      return;
664
665   }
666
667   long hsa = (sa + 1) >> 1;
668
669   if (hsa < sb) {
670      /* normal case */
671
672      long hsa2 = hsa << 1;
673
674      ZZ *T1, *T2, *T3;
675
676      T1 = stk; stk += hsa;
677      T2 = stk; stk += hsa;
678      T3 = stk; stk += hsa2 - 1;
679
680      /* compute T1 = a_lo + a_hi */
681
682      KarFold(T1, a, sa, hsa);
683
684      /* compute T2 = b_lo + b_hi */
685
686      KarFold(T2, b, sb, hsa);
687
688      /* recursively compute T3 = T1 * T2 */
689
690      KarMul(T3, T1, hsa, T2, hsa, stk);
691
692      /* recursively compute a_hi * b_hi into high part of c */
693      /* and subtract from T3 */
694
695      KarMul(c + hsa2, a+hsa, sa-hsa, b+hsa, sb-hsa, stk);
696      KarSub(T3, c + hsa2, sa + sb - hsa2 - 1);
697
698
699      /* recursively compute a_lo*b_lo into low part of c */
700      /* and subtract from T3 */
701
702      KarMul(c, a, hsa, b, hsa, stk);
703      KarSub(T3, c, hsa2 - 1);
704
705      clear(c[hsa2 - 1]);
706
707      /* finally, add T3 * X^{hsa} to c */
708
709      KarAdd(c+hsa, T3, hsa2-1);
710   }
711   else {
712      /* degenerate case */
713
714      ZZ *T;
715
716      T = stk; stk += hsa + sb - 1;
717
718      /* recursively compute b*a_hi into high part of c */
719
720      KarMul(c + hsa, a + hsa, sa - hsa, b, sb, stk);
721
722      /* recursively compute b*a_lo into T */
723
724      KarMul(T, a, hsa, b, sb, stk);
725
726      KarFix(c, T, hsa + sb - 1, hsa);
727   }
728}
729
730void KarMul(ZZX& c, const ZZX& a, const ZZX& b)
731{
732   if (IsZero(a) || IsZero(b)) {
733      clear(c);
734      return;
735   }
736
737   if (&a == &b) {
738      KarSqr(c, a);
739      return;
740   }
741
742   vec_ZZ mem;
743
744   const ZZ *ap, *bp;
745   ZZ *cp;
746
747   long sa = a.rep.length();
748   long sb = b.rep.length();
749
750   if (&a == &c) {
751      mem = a.rep;
752      ap = mem.elts();
753   }
754   else
755      ap = a.rep.elts();
756
757   if (&b == &c) {
758      mem = b.rep;
759      bp = mem.elts();
760   }
761   else
762      bp = b.rep.elts();
763
764   c.rep.SetLength(sa+sb-1);
765   cp = c.rep.elts();
766
767   long maxa, maxb, xover;
768
769   maxa = MaxBits(a);
770   maxb = MaxBits(b);
771   xover = 2;
772
773   if (sa < xover || sb < xover)
774      PlainMul(cp, ap, sa, bp, sb);
775   else {
776      /* karatsuba */
777
778      long n, hn, sp, depth;
779
780      n = max(sa, sb);
781      sp = 0;
782      depth = 0;
783      do {
784         hn = (n+1) >> 1;
785         sp += (hn << 2) - 1;
786         n = hn;
787         depth++;
788      } while (n >= xover);
789
790      ZZVec stk;
791      stk.SetSize(sp, 
792         ((maxa + maxb + NumBits(min(sa, sb)) + 2*depth + 10) 
793          + NTL_ZZ_NBITS-1)/NTL_ZZ_NBITS);
794
795      KarMul(cp, ap, sa, bp, sb, stk.elts());
796   }
797
798   c.normalize();
799}
800
801
802
803
804
805
806void PlainSqr(ZZ* xp, const ZZ* ap, long sa)
807{
808   if (sa == 0) return;
809
810   long da = sa-1;
811   long d = 2*da;
812
813   long i, j, jmin, jmax;
814   long m, m2;
815   static ZZ t, accum;
816
817   for (i = 0; i <= d; i++) {
818      jmin = max(0, i-da);
819      jmax = min(da, i);
820      m = jmax - jmin + 1;
821      m2 = m >> 1;
822      jmax = jmin + m2 - 1;
823      clear(accum);
824      for (j = jmin; j <= jmax; j++) {
825         mul(t, ap[j], ap[i-j]);
826         add(accum, accum, t);
827      }
828      add(accum, accum, accum);
829      if (m & 1) {
830         sqr(t, ap[jmax + 1]);
831         add(accum, accum, t);
832      }
833
834      xp[i] = accum;
835   }
836}
837
838
839static
840void KarSqr(ZZ *c, const ZZ *a, long sa, ZZ *stk)
841{
842   if (sa == 1) {
843      sqr(*c, *a);
844      return;
845   }
846
847   if (sa == 2) {
848      sqr(c[0], a[0]);
849      sqr(c[2], a[1]);
850      mul(c[1], a[0], a[1]);
851      add(c[1], c[1], c[1]);
852
853      return;
854   }
855
856   if (sa == 3) {
857      sqr(c[0], a[0]);
858      mul(c[1], a[0], a[1]);
859      add(c[1], c[1], c[1]);
860      sqr(stk[0], a[1]);
861      mul(c[2], a[0], a[2]);
862      add(c[2], c[2], c[2]);
863      add(c[2], c[2], stk[0]);
864      mul(c[3], a[1], a[2]);
865      add(c[3], c[3], c[3]);
866      sqr(c[4], a[2]);
867
868      return;
869 
870   }
871
872   long hsa = (sa + 1) >> 1;
873   long hsa2 = hsa << 1;
874
875   ZZ *T1, *T2;
876
877   T1 = stk; stk += hsa;
878   T2 = stk; stk += hsa2-1;
879
880   KarFold(T1, a, sa, hsa);
881   KarSqr(T2, T1, hsa, stk);
882
883
884   KarSqr(c + hsa2, a+hsa, sa-hsa, stk);
885   KarSub(T2, c + hsa2, sa + sa - hsa2 - 1);
886
887
888   KarSqr(c, a, hsa, stk);
889   KarSub(T2, c, hsa2 - 1);
890
891   clear(c[hsa2 - 1]);
892
893   KarAdd(c+hsa, T2, hsa2-1);
894}
895
896     
897void KarSqr(ZZX& c, const ZZX& a)
898{
899   if (IsZero(a)) {
900      clear(c);
901      return;
902   }
903
904   vec_ZZ mem;
905
906   const ZZ *ap;
907   ZZ *cp;
908
909   long sa = a.rep.length();
910
911   if (&a == &c) {
912      mem = a.rep;
913      ap = mem.elts();
914   }
915   else
916      ap = a.rep.elts();
917
918   c.rep.SetLength(sa+sa-1);
919   cp = c.rep.elts();
920
921   long maxa, xover;
922
923   maxa = MaxBits(a);
924
925   xover = 2;
926
927   if (sa < xover)
928      PlainSqr(cp, ap, sa);
929   else {
930      /* karatsuba */
931
932      long n, hn, sp, depth;
933
934      n = sa;
935      sp = 0;
936      depth = 0;
937      do {
938         hn = (n+1) >> 1;
939         sp += hn+hn+hn - 1;
940         n = hn;
941         depth++;
942      } while (n >= xover);
943
944      ZZVec stk;
945      stk.SetSize(sp, 
946         ((2*maxa + NumBits(sa) + 2*depth + 10) 
947          + NTL_ZZ_NBITS-1)/NTL_ZZ_NBITS);
948
949      KarSqr(cp, ap, sa, stk.elts());
950   }
951
952   c.normalize();
953}
954
955NTL_END_IMPL
Note: See TracBrowser for help on using the repository browser.