source: git/ntl/src/ZZX.c @ 287cc8

spielwiese
Last change on this file since 287cc8 was 287cc8, checked in by Hans Schönemann <hannes@…>, 14 years ago
ntl 5.5.2 git-svn-id: file:///usr/local/Singular/svn/trunk@12402 2c84dea3-7e68-4137-9b89-c4e89433aadc
  • Property mode set to 100644
File size: 15.0 KB
Line 
1
2#include <NTL/ZZX.h>
3
4#include <NTL/new.h>
5
6NTL_START_IMPL
7
8
9
10const ZZX& ZZX::zero()
11{
12   static ZZX z;
13   return z;
14}
15
16
17
18void conv(ZZ_pX& x, const ZZX& a)
19{
20   conv(x.rep, a.rep);
21   x.normalize();
22}
23
24void conv(ZZX& x, const ZZ_pX& a)
25{
26   conv(x.rep, a.rep);
27   x.normalize();
28}
29
30
31void ZZX::normalize()
32{
33   long n;
34   const ZZ* p;
35
36   n = rep.length();
37   if (n == 0) return;
38   p = rep.elts() + n;
39   while (n > 0 && IsZero(*--p)) {
40      n--;
41   }
42   rep.SetLength(n);
43}
44
45
46long IsZero(const ZZX& a)
47{
48   return a.rep.length() == 0;
49}
50
51
52long IsOne(const ZZX& a)
53{
54    return a.rep.length() == 1 && IsOne(a.rep[0]);
55}
56
57long operator==(const ZZX& a, const ZZX& b)
58{
59   long i, n;
60   const ZZ *ap, *bp;
61
62   n = a.rep.length();
63   if (n != b.rep.length()) return 0;
64
65   ap = a.rep.elts();
66   bp = b.rep.elts();
67
68   for (i = 0; i < n; i++)
69      if (ap[i] != bp[i]) return 0;
70
71   return 1;
72}
73
74
75long operator==(const ZZX& a, long b)
76{
77   if (b == 0)
78      return IsZero(a);
79
80   if (deg(a) != 0)
81      return 0;
82
83   return a.rep[0] == b;
84}
85
86long operator==(const ZZX& a, const ZZ& b)
87{
88   if (IsZero(b))
89      return IsZero(a);
90
91   if (deg(a) != 0)
92      return 0;
93
94   return a.rep[0] == b;
95}
96
97
98void GetCoeff(ZZ& x, const ZZX& a, long i)
99{
100   if (i < 0 || i > deg(a))
101      clear(x);
102   else
103      x = a.rep[i];
104}
105
106void SetCoeff(ZZX& x, long i, const ZZ& a)
107{
108   long j, m;
109
110   if (i < 0)
111      Error("SetCoeff: negative index");
112
113   if (NTL_OVERFLOW(i, 1, 0))
114      Error("overflow in SetCoeff");
115
116   m = deg(x);
117
118   if (i > m && IsZero(a)) return;
119
120   if (i > m) {
121      /* careful: a may alias a coefficient of x */
122
123      long alloc = x.rep.allocated();
124
125      if (alloc > 0 && i >= alloc) {
126         ZZ aa = a;
127         x.rep.SetLength(i+1);
128         x.rep[i] = aa;
129      }
130      else {
131         x.rep.SetLength(i+1);
132         x.rep[i] = a;
133      }
134
135      for (j = m+1; j < i; j++)
136         clear(x.rep[j]);
137   }
138   else
139      x.rep[i] = a;
140
141   x.normalize();
142}
143
144
145void SetCoeff(ZZX& x, long i)
146{
147   long j, m;
148
149   if (i < 0)
150      Error("coefficient index out of range");
151
152   if (NTL_OVERFLOW(i, 1, 0))
153      Error("overflow in SetCoeff");
154
155   m = deg(x);
156
157   if (i > m) {
158      x.rep.SetLength(i+1);
159      for (j = m+1; j < i; j++)
160         clear(x.rep[j]);
161   }
162   set(x.rep[i]);
163   x.normalize();
164}
165
166
167void SetX(ZZX& x)
168{
169   clear(x);
170   SetCoeff(x, 1);
171}
172
173
174long IsX(const ZZX& a)
175{
176   return deg(a) == 1 && IsOne(LeadCoeff(a)) && IsZero(ConstTerm(a));
177}
178
179
180
181const ZZ& coeff(const ZZX& a, long i)
182{
183   if (i < 0 || i > deg(a))
184      return ZZ::zero();
185   else
186      return a.rep[i];
187}
188
189
190const ZZ& LeadCoeff(const ZZX& a)
191{
192   if (IsZero(a))
193      return ZZ::zero();
194   else
195      return a.rep[deg(a)];
196}
197
198const ZZ& ConstTerm(const ZZX& a)
199{
200   if (IsZero(a))
201      return ZZ::zero();
202   else
203      return a.rep[0];
204}
205
206
207
208void conv(ZZX& x, const ZZ& a)
209{
210   if (IsZero(a))
211      x.rep.SetLength(0);
212   else {
213      x.rep.SetLength(1);
214      x.rep[0] = a;
215   }
216}
217
218
219void conv(ZZX& x, long a)
220{
221   if (a == 0)
222      x.rep.SetLength(0);
223   else {
224      x.rep.SetLength(1);
225      conv(x.rep[0], a);
226   }
227}
228
229
230void conv(ZZX& x, const vec_ZZ& a)
231{
232   x.rep = a;
233   x.normalize();
234}
235
236
237void add(ZZX& x, const ZZX& a, const ZZX& b)
238{
239   long da = deg(a);
240   long db = deg(b);
241   long minab = min(da, db);
242   long maxab = max(da, db);
243   x.rep.SetLength(maxab+1);
244
245   long i;
246   const ZZ *ap, *bp;
247   ZZ* xp;
248
249   for (i = minab+1, ap = a.rep.elts(), bp = b.rep.elts(), xp = x.rep.elts();
250        i; i--, ap++, bp++, xp++)
251      add(*xp, (*ap), (*bp));
252
253   if (da > minab && &x != &a)
254      for (i = da-minab; i; i--, xp++, ap++)
255         *xp = *ap;
256   else if (db > minab && &x != &b)
257      for (i = db-minab; i; i--, xp++, bp++)
258         *xp = *bp;
259   else
260      x.normalize();
261}
262
263void add(ZZX& x, const ZZX& a, const ZZ& b)
264{
265   long n = a.rep.length();
266   if (n == 0) {
267      conv(x, b);
268   }
269   else if (&x == &a) {
270      add(x.rep[0], a.rep[0], b);
271      x.normalize();
272   }
273   else if (x.rep.MaxLength() == 0) {
274      x = a;
275      add(x.rep[0], a.rep[0], b);
276      x.normalize();
277   }
278   else {
279      // ugly...b could alias a coeff of x
280
281      ZZ *xp = x.rep.elts();
282      add(xp[0], a.rep[0], b);
283      x.rep.SetLength(n);
284      xp = x.rep.elts();
285      const ZZ *ap = a.rep.elts();
286      long i;
287      for (i = 1; i < n; i++)
288         xp[i] = ap[i];
289      x.normalize();
290   }
291}
292
293
294void add(ZZX& x, const ZZX& a, long b)
295{
296   if (a.rep.length() == 0) {
297      conv(x, b);
298   }
299   else {
300      if (&x != &a) x = a;
301      add(x.rep[0], x.rep[0], b);
302      x.normalize();
303   }
304}
305
306
307void sub(ZZX& x, const ZZX& a, const ZZX& b)
308{
309   long da = deg(a);
310   long db = deg(b);
311   long minab = min(da, db);
312   long maxab = max(da, db);
313   x.rep.SetLength(maxab+1);
314
315   long i;
316   const ZZ *ap, *bp;
317   ZZ* xp;
318
319   for (i = minab+1, ap = a.rep.elts(), bp = b.rep.elts(), xp = x.rep.elts();
320        i; i--, ap++, bp++, xp++)
321      sub(*xp, (*ap), (*bp));
322
323   if (da > minab && &x != &a)
324      for (i = da-minab; i; i--, xp++, ap++)
325         *xp = *ap;
326   else if (db > minab)
327      for (i = db-minab; i; i--, xp++, bp++)
328         negate(*xp, *bp);
329   else
330      x.normalize();
331
332}
333
334void sub(ZZX& x, const ZZX& a, const ZZ& b)
335{
336   long n = a.rep.length();
337   if (n == 0) {
338      conv(x, b);
339      negate(x, x);
340   }
341   else if (&x == &a) {
342      sub(x.rep[0], a.rep[0], b);
343      x.normalize();
344   }
345   else if (x.rep.MaxLength() == 0) {
346      x = a;
347      sub(x.rep[0], a.rep[0], b);
348      x.normalize();
349   }
350   else {
351      // ugly...b could alias a coeff of x
352
353      ZZ *xp = x.rep.elts();
354      sub(xp[0], a.rep[0], b);
355      x.rep.SetLength(n);
356      xp = x.rep.elts();
357      const ZZ *ap = a.rep.elts();
358      long i;
359      for (i = 1; i < n; i++)
360         xp[i] = ap[i];
361      x.normalize();
362   }
363}
364
365void sub(ZZX& x, const ZZX& a, long b)
366{
367   if (b == 0) {
368      x = a;
369      return;
370   }
371
372   if (a.rep.length() == 0) {
373      x.rep.SetLength(1);
374      conv(x.rep[0], b);
375      negate(x.rep[0], x.rep[0]);
376   }
377   else {
378      if (&x != &a) x = a;
379      sub(x.rep[0], x.rep[0], b);
380   }
381   x.normalize();
382}
383
384void sub(ZZX& x, long a, const ZZX& b)
385{
386   negate(x, b);
387   add(x, x, a);
388}
389
390
391void sub(ZZX& x, const ZZ& b, const ZZX& a)
392{
393   long n = a.rep.length();
394   if (n == 0) {
395      conv(x, b);
396   }
397   else if (x.rep.MaxLength() == 0) {
398      negate(x, a);
399      add(x.rep[0], a.rep[0], b);
400      x.normalize();
401   }
402   else {
403      // ugly...b could alias a coeff of x
404
405      ZZ *xp = x.rep.elts();
406      sub(xp[0], b, a.rep[0]);
407      x.rep.SetLength(n);
408      xp = x.rep.elts();
409      const ZZ *ap = a.rep.elts();
410      long i;
411      for (i = 1; i < n; i++)
412         negate(xp[i], ap[i]);
413      x.normalize();
414   }
415}
416
417
418
419void negate(ZZX& x, const ZZX& a)
420{
421   long n = a.rep.length();
422   x.rep.SetLength(n);
423
424   const ZZ* ap = a.rep.elts();
425   ZZ* xp = x.rep.elts();
426   long i;
427
428   for (i = n; i; i--, ap++, xp++)
429      negate((*xp), (*ap));
430
431}
432
433long MaxBits(const ZZX& f)
434{
435   long i, m;
436   m = 0;
437
438   for (i = 0; i <= deg(f); i++) {
439      m = max(m, NumBits(f.rep[i]));
440   }
441
442   return m;
443}
444
445
446void PlainMul(ZZX& x, const ZZX& a, const ZZX& b)
447{
448   if (&a == &b) {
449      PlainSqr(x, a);
450      return;
451   }
452
453   long da = deg(a);
454   long db = deg(b);
455
456   if (da < 0 || db < 0) {
457      clear(x);
458      return;
459   }
460
461   long d = da+db;
462
463
464
465   const ZZ *ap, *bp;
466   ZZ *xp;
467
468   ZZX la, lb;
469
470   if (&x == &a) {
471      la = a;
472      ap = la.rep.elts();
473   }
474   else
475      ap = a.rep.elts();
476
477   if (&x == &b) {
478      lb = b;
479      bp = lb.rep.elts();
480   }
481   else
482      bp = b.rep.elts();
483
484   x.rep.SetLength(d+1);
485
486   xp = x.rep.elts();
487
488   long i, j, jmin, jmax;
489   ZZ t, accum;
490
491   for (i = 0; i <= d; i++) {
492      jmin = max(0, i-db);
493      jmax = min(da, i);
494      clear(accum);
495      for (j = jmin; j <= jmax; j++) {
496         mul(t, ap[j], bp[i-j]);
497         add(accum, accum, t);
498      }
499      xp[i] = accum;
500   }
501   x.normalize();
502}
503
504void PlainSqr(ZZX& x, const ZZX& a)
505{
506   long da = deg(a);
507
508   if (da < 0) {
509      clear(x);
510      return;
511   }
512
513   long d = 2*da;
514
515   const ZZ *ap;
516   ZZ *xp;
517
518   ZZX la;
519
520   if (&x == &a) {
521      la = a;
522      ap = la.rep.elts();
523   }
524   else
525      ap = a.rep.elts();
526
527
528   x.rep.SetLength(d+1);
529
530   xp = x.rep.elts();
531
532   long i, j, jmin, jmax;
533   long m, m2;
534   ZZ t, accum;
535
536   for (i = 0; i <= d; i++) {
537      jmin = max(0, i-da);
538      jmax = min(da, i);
539      m = jmax - jmin + 1;
540      m2 = m >> 1;
541      jmax = jmin + m2 - 1;
542      clear(accum);
543      for (j = jmin; j <= jmax; j++) {
544         mul(t, ap[j], ap[i-j]);
545         add(accum, accum, t);
546      }
547      add(accum, accum, accum);
548      if (m & 1) {
549         sqr(t, ap[jmax + 1]);
550         add(accum, accum, t);
551      }
552
553      xp[i] = accum;
554   }
555
556   x.normalize();
557}
558
559
560
561static
562void PlainMul(ZZ *xp, const ZZ *ap, long sa, const ZZ *bp, long sb)
563{
564   if (sa == 0 || sb == 0) return;
565
566   long sx = sa+sb-1;
567
568   long i, j, jmin, jmax;
569   static ZZ t, accum;
570
571   for (i = 0; i < sx; i++) {
572      jmin = max(0, i-sb+1);
573      jmax = min(sa-1, i);
574      clear(accum);
575      for (j = jmin; j <= jmax; j++) {
576         mul(t, ap[j], bp[i-j]);
577         add(accum, accum, t);
578      }
579      xp[i] = accum;
580   }
581}
582
583
584
585static
586void KarFold(ZZ *T, const ZZ *b, long sb, long hsa)
587{
588   long m = sb - hsa;
589   long i;
590
591   for (i = 0; i < m; i++)
592      add(T[i], b[i], b[hsa+i]);
593
594   for (i = m; i < hsa; i++)
595      T[i] = b[i];
596}
597
598static
599void KarSub(ZZ *T, const ZZ *b, long sb)
600{
601   long i;
602
603   for (i = 0; i < sb; i++)
604      sub(T[i], T[i], b[i]);
605}
606
607static
608void KarAdd(ZZ *T, const ZZ *b, long sb)
609{
610   long i;
611
612   for (i = 0; i < sb; i++)
613      add(T[i], T[i], b[i]);
614}
615
616static
617void KarFix(ZZ *c, const ZZ *b, long sb, long hsa)
618{
619   long i;
620
621   for (i = 0; i < hsa; i++)
622      c[i] = b[i];
623
624   for (i = hsa; i < sb; i++)
625      add(c[i], c[i], b[i]);
626}
627
628static void PlainMul1(ZZ *xp, const ZZ *ap, long sa, const ZZ& b)
629{
630   long i;
631
632   for (i = 0; i < sa; i++)
633      mul(xp[i], ap[i], b);
634}
635
636
637
638static
639void KarMul(ZZ *c, const ZZ *a,
640            long sa, const ZZ *b, long sb, ZZ *stk)
641{
642   if (sa < sb) {
643      { long t = sa; sa = sb; sb = t; }
644      { const ZZ *t = a; a = b; b = t; }
645   }
646
647   if (sb == 1) {
648      if (sa == 1)
649         mul(*c, *a, *b);
650      else
651         PlainMul1(c, a, sa, *b);
652
653      return;
654   }
655
656   if (sb == 2 && sa == 2) {
657      mul(c[0], a[0], b[0]);
658      mul(c[2], a[1], b[1]);
659      add(stk[0], a[0], a[1]);
660      add(stk[1], b[0], b[1]);
661      mul(c[1], stk[0], stk[1]);
662      sub(c[1], c[1], c[0]);
663      sub(c[1], c[1], c[2]);
664
665      return;
666
667   }
668
669   long hsa = (sa + 1) >> 1;
670
671   if (hsa < sb) {
672      /* normal case */
673
674      long hsa2 = hsa << 1;
675
676      ZZ *T1, *T2, *T3;
677
678      T1 = stk; stk += hsa;
679      T2 = stk; stk += hsa;
680      T3 = stk; stk += hsa2 - 1;
681
682      /* compute T1 = a_lo + a_hi */
683
684      KarFold(T1, a, sa, hsa);
685
686      /* compute T2 = b_lo + b_hi */
687
688      KarFold(T2, b, sb, hsa);
689
690      /* recursively compute T3 = T1 * T2 */
691
692      KarMul(T3, T1, hsa, T2, hsa, stk);
693
694      /* recursively compute a_hi * b_hi into high part of c */
695      /* and subtract from T3 */
696
697      KarMul(c + hsa2, a+hsa, sa-hsa, b+hsa, sb-hsa, stk);
698      KarSub(T3, c + hsa2, sa + sb - hsa2 - 1);
699
700
701      /* recursively compute a_lo*b_lo into low part of c */
702      /* and subtract from T3 */
703
704      KarMul(c, a, hsa, b, hsa, stk);
705      KarSub(T3, c, hsa2 - 1);
706
707      clear(c[hsa2 - 1]);
708
709      /* finally, add T3 * X^{hsa} to c */
710
711      KarAdd(c+hsa, T3, hsa2-1);
712   }
713   else {
714      /* degenerate case */
715
716      ZZ *T;
717
718      T = stk; stk += hsa + sb - 1;
719
720      /* recursively compute b*a_hi into high part of c */
721
722      KarMul(c + hsa, a + hsa, sa - hsa, b, sb, stk);
723
724      /* recursively compute b*a_lo into T */
725
726      KarMul(T, a, hsa, b, sb, stk);
727
728      KarFix(c, T, hsa + sb - 1, hsa);
729   }
730}
731
732void KarMul(ZZX& c, const ZZX& a, const ZZX& b)
733{
734   if (IsZero(a) || IsZero(b)) {
735      clear(c);
736      return;
737   }
738
739   if (&a == &b) {
740      KarSqr(c, a);
741      return;
742   }
743
744   vec_ZZ mem;
745
746   const ZZ *ap, *bp;
747   ZZ *cp;
748
749   long sa = a.rep.length();
750   long sb = b.rep.length();
751
752   if (&a == &c) {
753      mem = a.rep;
754      ap = mem.elts();
755   }
756   else
757      ap = a.rep.elts();
758
759   if (&b == &c) {
760      mem = b.rep;
761      bp = mem.elts();
762   }
763   else
764      bp = b.rep.elts();
765
766   c.rep.SetLength(sa+sb-1);
767   cp = c.rep.elts();
768
769   long maxa, maxb, xover;
770
771   maxa = MaxBits(a);
772   maxb = MaxBits(b);
773   xover = 2;
774
775   if (sa < xover || sb < xover)
776      PlainMul(cp, ap, sa, bp, sb);
777   else {
778      /* karatsuba */
779
780      long n, hn, sp, depth;
781
782      n = max(sa, sb);
783      sp = 0;
784      depth = 0;
785      do {
786         hn = (n+1) >> 1;
787         sp += (hn << 2) - 1;
788         n = hn;
789         depth++;
790      } while (n >= xover);
791
792      ZZVec stk;
793      stk.SetSize(sp,
794         ((maxa + maxb + NumBits(min(sa, sb)) + 2*depth + 10)
795          + NTL_ZZ_NBITS-1)/NTL_ZZ_NBITS);
796
797      KarMul(cp, ap, sa, bp, sb, stk.elts());
798   }
799
800   c.normalize();
801}
802
803
804
805
806
807
808void PlainSqr(ZZ* xp, const ZZ* ap, long sa)
809{
810   if (sa == 0) return;
811
812   long da = sa-1;
813   long d = 2*da;
814
815   long i, j, jmin, jmax;
816   long m, m2;
817   static ZZ t, accum;
818
819   for (i = 0; i <= d; i++) {
820      jmin = max(0, i-da);
821      jmax = min(da, i);
822      m = jmax - jmin + 1;
823      m2 = m >> 1;
824      jmax = jmin + m2 - 1;
825      clear(accum);
826      for (j = jmin; j <= jmax; j++) {
827         mul(t, ap[j], ap[i-j]);
828         add(accum, accum, t);
829      }
830      add(accum, accum, accum);
831      if (m & 1) {
832         sqr(t, ap[jmax + 1]);
833         add(accum, accum, t);
834      }
835
836      xp[i] = accum;
837   }
838}
839
840
841static
842void KarSqr(ZZ *c, const ZZ *a, long sa, ZZ *stk)
843{
844   if (sa == 1) {
845      sqr(*c, *a);
846      return;
847   }
848
849   if (sa == 2) {
850      sqr(c[0], a[0]);
851      sqr(c[2], a[1]);
852      mul(c[1], a[0], a[1]);
853      add(c[1], c[1], c[1]);
854
855      return;
856   }
857
858   if (sa == 3) {
859      sqr(c[0], a[0]);
860      mul(c[1], a[0], a[1]);
861      add(c[1], c[1], c[1]);
862      sqr(stk[0], a[1]);
863      mul(c[2], a[0], a[2]);
864      add(c[2], c[2], c[2]);
865      add(c[2], c[2], stk[0]);
866      mul(c[3], a[1], a[2]);
867      add(c[3], c[3], c[3]);
868      sqr(c[4], a[2]);
869
870      return;
871
872   }
873
874   long hsa = (sa + 1) >> 1;
875   long hsa2 = hsa << 1;
876
877   ZZ *T1, *T2;
878
879   T1 = stk; stk += hsa;
880   T2 = stk; stk += hsa2-1;
881
882   KarFold(T1, a, sa, hsa);
883   KarSqr(T2, T1, hsa, stk);
884
885
886   KarSqr(c + hsa2, a+hsa, sa-hsa, stk);
887   KarSub(T2, c + hsa2, sa + sa - hsa2 - 1);
888
889
890   KarSqr(c, a, hsa, stk);
891   KarSub(T2, c, hsa2 - 1);
892
893   clear(c[hsa2 - 1]);
894
895   KarAdd(c+hsa, T2, hsa2-1);
896}
897
898
899void KarSqr(ZZX& c, const ZZX& a)
900{
901   if (IsZero(a)) {
902      clear(c);
903      return;
904   }
905
906   vec_ZZ mem;
907
908   const ZZ *ap;
909   ZZ *cp;
910
911   long sa = a.rep.length();
912
913   if (&a == &c) {
914      mem = a.rep;
915      ap = mem.elts();
916   }
917   else
918      ap = a.rep.elts();
919
920   c.rep.SetLength(sa+sa-1);
921   cp = c.rep.elts();
922
923   long maxa, xover;
924
925   maxa = MaxBits(a);
926
927   xover = 2;
928
929   if (sa < xover)
930      PlainSqr(cp, ap, sa);
931   else {
932      /* karatsuba */
933
934      long n, hn, sp, depth;
935
936      n = sa;
937      sp = 0;
938      depth = 0;
939      do {
940         hn = (n+1) >> 1;
941         sp += hn+hn+hn - 1;
942         n = hn;
943         depth++;
944      } while (n >= xover);
945
946      ZZVec stk;
947      stk.SetSize(sp,
948         ((2*maxa + NumBits(sa) + 2*depth + 10)
949          + NTL_ZZ_NBITS-1)/NTL_ZZ_NBITS);
950
951      KarSqr(cp, ap, sa, stk.elts());
952   }
953
954   c.normalize();
955}
956
957NTL_END_IMPL
Note: See TracBrowser for help on using the repository browser.