matmul.h
1 /* Copyright (C) 2012-2020 IBM Corp.
2  * This program is Licensed under the Apache License, Version 2.0
3  * (the "License"); you may not use this file except in compliance
4  * with the License. You may obtain a copy of the License at
5  * http://www.apache.org/licenses/LICENSE-2.0
6  * Unless required by applicable law or agreed to in writing, software
7  * distributed under the License is distributed on an "AS IS" BASIS,
8  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9  * See the License for the specific language governing permissions and
10  * limitations under the License. See accompanying LICENSE file.
11  */
12 #ifndef HELIB_MATMUL_H
13 #define HELIB_MATMUL_H
14 
15 #include <helib/EncryptedArray.h>
16 #include <functional>
17 
18 namespace helib {
19 
20 class MatMulFullExec;
21 
22 // Abstract base class for representing a linear transformation on a full
23 // vector.
25 {
26 public:
27  virtual ~MatMulFull() {}
28  virtual const EncryptedArray& getEA() const = 0;
30 };
31 
32 // Concrete derived class that defines the matrix entries.
33 template <typename type>
35 {
36 public:
37  PA_INJECT(type)
38 
39  // Get (i, j) entry of matrix.
40  // Should return true when the entry is a zero.
41  virtual bool get(RX& out, long i, long j) const = 0;
42 };
43 
44 //====================================
45 
47 
48 // Abstract base class for representing a block linear transformation on a full
49 // vector.
51 {
52 public:
53  virtual ~BlockMatMulFull() {}
54  virtual const EncryptedArray& getEA() const = 0;
56 };
57 
58 // Concrete derived class that defines the matrix entries.
59 template <typename type>
61 {
62 public:
63  PA_INJECT(type)
64 
65  // Get (i, j) entry of matrix.
66  // Each entry is a d x d matrix over the base ring.
67  // Should return true when the entry is a zero.
68  virtual bool get(mat_R& out, long i, long j) const = 0;
69 };
70 
71 //====================================
72 
73 class MatMul1DExec;
74 
75 // Abstract base class for representing a 1D linear transformation.
76 class MatMul1D
77 {
78 public:
    // Virtual destructor: MatMul1D is a polymorphic interface and is handled
    // through base-class references (e.g. by the mul() and operator*=
    // overloads further down in this header).
79  virtual ~MatMul1D() {}
    // Returns the EncryptedArray this transformation is tied to.
    // mul(PtxtArray&, const MatMul1D&) requires it to be the very same object
    // as the PtxtArray's own EncryptedArray.
80  virtual const EncryptedArray& getEA() const = 0;
    // Returns the dimension index of this 1D transformation.
    // NOTE(review): presumably the hypercube dimension along which the
    // transformation is applied -- confirm against the implementation.
81  virtual long getDim() const = 0;
83 };
84 
85 // An intermediate class that is mainly intended for internal use.
// Adds the per-diagonal encoding hook (processDiagonal) on top of the
// MatMul1D interface, parameterised by the plaintext-arithmetic tag `type`.
86 template <typename type>
87 class MatMul1D_partial : public MatMul1D
88 {
89 public:
    // NOTE(review): PA_INJECT presumably brings the type aliases that belong
    // to `type` (such as RX, used below) into class scope -- confirm against
    // the macro definition.
90  PA_INJECT(type)
91 
92  // Get the i'th diagonal, encoded as a single constant.
93  // MatMul1D_derived (below) supplies a default implementation,
94  // which can be overridden in special circumstances.
    //   poly - output parameter that receives the encoded diagonal
    //   i    - index of the diagonal to process
    //   ea   - the derived encrypted array for `type`
95  virtual void processDiagonal(RX& poly,
96  long i,
97  const EncryptedArrayDerived<type>& ea) const = 0;
98 };
99 
100 // Concrete derived class that defines the matrix entries.
// User-facing base class for a 1D linear transformation: subclasses describe
// the matrix through multipleTransforms() and get(), and this class declares
// the processDiagonal() override required by MatMul1D_partial.
101 template <typename type>
102 class MatMul1D_derived : public MatMul1D_partial<type>
103 {
104 public:
     // NOTE(review): PA_INJECT presumably brings the type aliases that belong
     // to `type` (such as RX, used below) into class scope -- confirm against
     // the macro definition.
105  PA_INJECT(type)
106 
107  // Should return true if there are multiple (different) transforms
108  // among the various components.
109  virtual bool multipleTransforms() const = 0;
110 
111  // Get coordinate (i, j) of the kth component.
112  // Should return true when the entry is a zero.
     //   out - output parameter for the entry
     //   NOTE(review): whether `out` must still be written when the function
     //   returns true (zero entry) is not visible here -- confirm.
113  virtual bool get(RX& out, long i, long j, long k) const = 0;
114 
     // Overrides MatMul1D_partial<type>::processDiagonal. Only declared here;
     // the definition is not part of this header.
115  void processDiagonal(RX& poly,
116  long i,
117  const EncryptedArrayDerived<type>& ea) const override;
118 };
119 
// Full specialization of MatMul1D_derived for PA_cx. Unlike the primary
// template it derives directly from MatMul1D (not from MatMul1D_partial),
// returns entries by value as std::complex<double>, and has no
// multipleTransforms() or component index k.
120 template <>
121 class MatMul1D_derived<PA_cx> : public MatMul1D
122 {
123 public:
124  // Get coordinate (i, j)
125  virtual std::complex<double> get(long i, long j) const = 0;
126 
     // Writes the i'th diagonal into diag as a vector of complex values.
     // This is a plain (non-virtual) member: MatMul1D declares no
     // processDiagonal, so nothing is overridden here. Only declared here;
     // the definition is not part of this header.
127  void processDiagonal(std::vector<std::complex<double>>& diag,
128  long i,
129  const EncryptedArrayCx& ea) const;
130 
131  // final: ensures that dim==0 is the only possible dimension
132  virtual long getDim() const final { return 0; }
133 };
134 
136 
137 // more convenient user interfaces
138 // VJS-FIXME: document some of this stuff
139 
141 {
142 public:
143  typedef std::function<double(long, long)> get_fun_type;
144 
145 private:
146  const EncryptedArray& ea;
147 
148  get_fun_type get_fun;
149  // get_fun(i,j) returns matrix entry (i,j)
150  // see get_fun_type definitions below
151 
152 public:
153  MatMul_CKKS(const EncryptedArray& _ea, get_fun_type _get_fun) :
154  ea(_ea), get_fun(_get_fun)
155  {}
156 
157  MatMul_CKKS(const Context& context, get_fun_type _get_fun) :
158  ea(context.getEA()), get_fun(_get_fun)
159  {}
160 
161  virtual const EncryptedArray& getEA() const override { return ea; }
162 
163  virtual std::complex<double> get(long i, long j) const override
164  {
165  return get_fun(i, j);
166  }
167 };
168 
170 {
171 public:
172  typedef std::function<std::complex<double>(long, long)> get_fun_type;
173 
174 private:
175  const EncryptedArray& ea;
176 
177  get_fun_type get_fun;
178  // get_fun(i,j) returns matrix entry (i,j)
179  // see get_fun_type definitions below
180 
181 public:
183  ea(_ea), get_fun(_get_fun)
184  {}
185 
186  MatMul_CKKS_Complex(const Context& context, get_fun_type _get_fun) :
187  ea(context.getEA()), get_fun(_get_fun)
188  {}
189 
190  virtual const EncryptedArray& getEA() const override { return ea; }
191 
192  virtual std::complex<double> get(long i, long j) const override
193  {
194  return get_fun(i, j);
195  }
196 };
197 
198 //====================================
199 
200 class BlockMatMul1DExec;
201 
202 // Abstract base class for representing a block 1D linear transformation.
204 {
205 public:
206  virtual ~BlockMatMul1D() {}
207  virtual const EncryptedArray& getEA() const = 0;
208  virtual long getDim() const = 0;
210 };
211 
212 // An intermediate class that is mainly intended for internal use.
213 template <typename type>
215 {
216 public:
217  PA_INJECT(type)
218 
219  // Get the i'th diagonal, encoded as a std::vector of d constants,
220  // where d is the order of p.
221  // BlockMatMul1D_derived (below) supplies a default implementation,
222  // which can be overridden in special circumstances.
223  virtual bool processDiagonal(std::vector<RX>& poly,
224  long i,
225  const EncryptedArrayDerived<type>& ea) const = 0;
226 };
227 
228 // Concrete derived class that defines the matrix entries.
229 template <typename type>
231 {
232 public:
233  PA_INJECT(type)
234 
235  // Should return true if there are multiple (different) transforms
236  // among the various components.
237  virtual bool multipleTransforms() const = 0;
238 
239  // Get coordinate (i, j) of the kth component.
240  // Each entry is a d x d matrix over the base ring.
241  // Should return true when the entry is a zero.
242  virtual bool get(mat_R& out, long i, long j, long k) const = 0;
243 
244  bool processDiagonal(std::vector<RX>& poly,
245  long i,
246  const EncryptedArrayDerived<type>& ea) const override;
247 };
248 
249 //====================================
250 
251 struct ConstMultiplier;
252 // Defined in matmul.cpp.
253 // Holds a constant by which a ciphertext can be multiplied.
254 // Internally, it is represented as either zzX or a DoubleCRT.
255 // The former occupies less space, but the latter makes for
256 // much faster multiplication.
257 
259 {
260  std::vector<std::shared_ptr<ConstMultiplier>> multiplier;
261 
262  // Upgrade zzX constants to DoubleCRT constants.
263  void upgrade(const Context& context);
264 };
265 
266 //====================================
267 
268 // Abstract base class for multiplying an encrypted vector by a plaintext
269 // matrix.
271 {
272 public:
273  virtual ~MatMulExecBase() {}
274 
275  virtual const EncryptedArray& getEA() const = 0;
276 
277  // Upgrade zzX constants to DoubleCRT constants.
278  virtual void upgrade() = 0;
279 
280  // If ctxt encrypts a row std::vector v, then this replaces ctxt
281  // by an encryption of the row std::vector v*mat, where mat is
282  // a matrix provided to the constructor of one of the
283  // concrete subclasses MatMul1DExec, BlockMatMul1DExec,
284  // MatMulFullExec, BlockMatMulFullExec, defined below.
285  virtual void mul(Ctxt& ctxt) const = 0;
286 };
287 
288 //====================================
289 
290 // Class used to multiply an encrypted row vector by a 1D linear
291 // transformation.
293 {
294 public:
296 
297  long dim;
298  long D;
299  bool native;
300  bool minimal;
301  long g;
302 
304  ConstMultiplierCache cache1; // only for non-native dimension
305 
306  // The constructor encodes all the constants for a given
307  // matrix in zzX format.
308  // The mat argument defines the entries of the matrix.
309  // Use the upgrade method (below) to convert to DoubleCRT format.
310  // If the minimal flag is set to true, a strategy that relies
311  // on a minimal number of key switching matrices will be used;
312  // this is intended for use in conjunction with the
313  // addMinimal{1D,Frb}Matrices routines declared in helib.h.
314  // If the minimal flag is false, it is best to use the
315  // addSome{1D,Frb}Matrices routines declared in helib.h.
316  explicit MatMul1DExec(const MatMul1D& mat, bool minimal = false);
317 
318  // VJS-FIXME: it seems that the minimal flag is currently
319  // redundant, as the decision is essentially based on
320  // ctxt.getPubKey().getKSStrategy(dim0). Need to look into this
321  // and re-assess.
322 
323  // Replaces an encryption of row std::vector v by encryption of v*mat
324  void mul(Ctxt& ctxt) const override;
325 
326  // Upgrades encoded constants from zzX to DoubleCRT.
327  void upgrade() override
328  {
329  cache.upgrade(ea.getContext());
330  cache1.upgrade(ea.getContext());
331  }
332 
333  const EncryptedArray& getEA() const override { return ea; }
334 };
335 
336 // A more convenient and naturally-named interface for CKKS
337 // VJS-FIXME: document some of this stuff
338 
340 {
341 public:
343 };
344 
345 //====================================
346 
347 // Class used to multiply an encrypted row vector by a block 1D linear
348 // transformation.
350 {
351 public:
353 
354  long dim;
355  long D;
356  long d;
357  bool native;
358  long strategy;
359 
361  ConstMultiplierCache cache1; // only for non-native dimension
362 
363  // The constructor encodes all the constants for a given
364  // matrix in zzX format.
365  // The mat argument defines the entries of the matrix.
366  // Use the upgrade method (below) to convert to DoubleCRT format.
367  // If the minimal flag is set to true, a strategy that relies
368  // on a minimal number of key switching matrices will be used;
369  // this is intended for use in conjunction with the
370  // addMinimal{1D,Frb}Matrices routines declared in helib.h.
371  // If the minimal flag is false, it is best to use the
372  // addSome{1D,Frb}Matrices routines declared in helib.h.
373  explicit BlockMatMul1DExec(const BlockMatMul1D& mat, bool minimal = false);
374 
375  // Replaces an encryption of row std::vector v by encryption of v*mat
376  void mul(Ctxt& ctxt) const override;
377 
378  // Upgrades encoded constants from zzX to DoubleCRT.
379  void upgrade() override
380  {
381  cache.upgrade(ea.getContext());
382  cache1.upgrade(ea.getContext());
383  }
384 
385  const EncryptedArray& getEA() const override { return ea; }
386 };
387 
388 //====================================
389 
390 // Class used to multiply an encrypted row vector by a full linear
391 // transformation.
393 {
394 public:
396  bool minimal;
397  std::vector<long> dims;
398  std::vector<MatMul1DExec> transforms;
399 
400  // The constructor encodes all the constants for a given
401  // matrix in zzX format.
402  // The mat argument defines the entries of the matrix.
403  // Use the upgrade method (below) to convert to DoubleCRT format.
404  // If the minimal flag is set to true, a strategy that relies
405  // on a minimal number of key switching matrices will be used;
406  // this is intended for use in conjunction with the
407  // addMinimal{1D,Frb}Matrices routines declared in helib.h.
408  // If the minimal flag is false, it is best to use the
409  // addSome{1D,Frb}Matrices routines declared in helib.h.
410  explicit MatMulFullExec(const MatMulFull& mat, bool minimal = false);
411 
412  // Replaces an encryption of row std::vector v by encryption of v*mat
413  void mul(Ctxt& ctxt) const override;
414 
415  // Upgrades encoded constants from zzX to DoubleCRT.
416  void upgrade() override
417  {
418  for (auto& t : transforms)
419  t.upgrade();
420  }
421 
422  const EncryptedArray& getEA() const override { return ea; }
423 
424  // This really should be private.
425  long rec_mul(Ctxt& acc, const Ctxt& ctxt, long dim, long idx) const;
426 };
427 
428 //====================================
429 
430 // Class used to multiply an encrypted row vector by a full block linear
431 // transformation.
433 {
434 public:
436  bool minimal;
437  std::vector<long> dims;
438  std::vector<BlockMatMul1DExec> transforms;
439 
440  // The constructor encodes all the constants for a given
441  // matrix in zzX format.
442  // The mat argument defines the entries of the matrix.
443  // Use the upgrade method (below) to convert to DoubleCRT format.
444  // If the minimal flag is set to true, a strategy that relies
445  // on a minimal number of key switching matrices will be used;
446  // this is intended for use in conjunction with the
447  // addMinimal{1D,Frb}Matrices routines declared in helib.h.
448  // If the minimal flag is false, it is best to use the
449  // addSome{1D,Frb}Matrices routines declared in helib.h.
450  explicit BlockMatMulFullExec(const BlockMatMulFull& mat,
451  bool minimal = false);
452 
453  // Replaces an encryption of row std::vector v by encryption of v*mat
454  void mul(Ctxt& ctxt) const override;
455 
456  // Upgrades encoded constants from zzX to DoubleCRT.
457  void upgrade() override
458  {
459  for (auto& t : transforms)
460  t.upgrade();
461  }
462 
463  const EncryptedArray& getEA() const override { return ea; }
464 
465  // This really should be private.
466  long rec_mul(Ctxt& acc, const Ctxt& ctxt, long dim, long idx) const;
467 };
468 
469 //===================================
470 
471 // ctxt = \sum_{i=0}^{d-1} \sigma^i(ctxt),
472 // where d = order of p mod m, and \sigma is the Frobenius map
473 
474 void traceMap(Ctxt& ctxt);
475 
476 //====================================
477 
478 // These routines apply linear transformation to plaintext arrays.
479 // Mainly for testing purposes.
480 void mul(PlaintextArray& pa, const MatMul1D& mat);
481 void mul(PlaintextArray& pa, const BlockMatMul1D& mat);
482 void mul(PlaintextArray& pa, const MatMulFull& mat);
483 void mul(PlaintextArray& pa, const BlockMatMulFull& mat);
484 
485 // VJS-FIXME: these should be documented
486 
// Applies the 1D linear transformation mat to the plaintext array a by
// forwarding to the PlaintextArray overload of mul.
// Precondition: a and mat must be tied to the very same EncryptedArray
// object; this is checked with assertTrue using the message
// "PtxtArray: inconsistent operation".
487 inline void mul(PtxtArray& a, const MatMul1D& mat)
488 {
     // Compares addresses, i.e. object identity, not structural equality.
489  assertTrue(&a.ea == &mat.getEA(), "PtxtArray: inconsistent operation");
     // Delegate to mul(PlaintextArray&, const MatMul1D&) on the wrapped array.
490  mul(a.pa, mat);
491 }
492 
// Applies the block 1D linear transformation mat to the plaintext array a by
// forwarding to the PlaintextArray overload of mul.
// Precondition: a and mat must be tied to the very same EncryptedArray
// object; this is checked with assertTrue using the message
// "PtxtArray: inconsistent operation".
493 inline void mul(PtxtArray& a, const BlockMatMul1D& mat)
494 {
     // Compares addresses, i.e. object identity, not structural equality.
495  assertTrue(&a.ea == &mat.getEA(), "PtxtArray: inconsistent operation");
     // Delegate to mul(PlaintextArray&, const BlockMatMul1D&) on the wrapped
     // array.
496  mul(a.pa, mat);
497 }
498 
// Applies the full linear transformation mat to the plaintext array a by
// forwarding to the PlaintextArray overload of mul.
// Precondition: a and mat must be tied to the very same EncryptedArray
// object; this is checked with assertTrue using the message
// "PtxtArray: inconsistent operation".
499 inline void mul(PtxtArray& a, const MatMulFull& mat)
500 {
     // Compares addresses, i.e. object identity, not structural equality.
501  assertTrue(&a.ea == &mat.getEA(), "PtxtArray: inconsistent operation");
     // Delegate to mul(PlaintextArray&, const MatMulFull&) on the wrapped array.
502  mul(a.pa, mat);
503 }
504 
// Applies the full block linear transformation mat to the plaintext array a
// by forwarding to the PlaintextArray overload of mul.
// Precondition: a and mat must be tied to the very same EncryptedArray
// object; this is checked with assertTrue using the message
// "PtxtArray: inconsistent operation".
505 inline void mul(PtxtArray& a, const BlockMatMulFull& mat)
506 {
     // Compares addresses, i.e. object identity, not structural equality.
507  assertTrue(&a.ea == &mat.getEA(), "PtxtArray: inconsistent operation");
     // Delegate to mul(PlaintextArray&, const BlockMatMulFull&) on the wrapped
     // array.
508  mul(a.pa, mat);
509 }
510 
511 // more interface conveniences, both for PtxtArray and Ctxt
512 
// Operator form of mul(PtxtArray&, const MatMul1D&): applies mat to a and
// returns a reference to a so that the expression can be chained.
513 inline PtxtArray& operator*=(PtxtArray& a, const MatMul1D& mat)
514 {
515  mul(a, mat);
516  return a;
517 }
518 
520 {
521  mul(a, mat);
522  return a;
523 }
524 
// Operator form of mul(PtxtArray&, const MatMulFull&): applies mat to a and
// returns a reference to a so that the expression can be chained.
525 inline PtxtArray& operator*=(PtxtArray& a, const MatMulFull& mat)
526 {
527  mul(a, mat);
528  return a;
529 }
530 
532 {
533  mul(a, mat);
534  return a;
535 }
536 
537 // For ctxt's, these functions don't do any pre-computation
538 
// Multiplies the ciphertext a by the 1D transformation mat and returns a.
// A fresh MatMul1DExec (constructed with its default `minimal` argument) is
// built on every call, so the encoded constants are not reused across calls;
// to reuse them, build the Exec object once and use the
// operator*=(Ctxt&, const MatMulExecBase&) overload instead.
539 inline Ctxt& operator*=(Ctxt& a, const MatMul1D& mat)
540 {
541  MatMul1DExec mat_exec(mat);
542  mat_exec.mul(a);
543  return a;
544 }
545 
// Multiplies the ciphertext a by the block 1D transformation mat and returns
// a. A fresh BlockMatMul1DExec (constructed with its default `minimal`
// argument) is built on every call, so the encoded constants are not reused
// across calls; to reuse them, build the Exec object once and use the
// operator*=(Ctxt&, const MatMulExecBase&) overload instead.
546 inline Ctxt& operator*=(Ctxt& a, const BlockMatMul1D& mat)
547 {
548  BlockMatMul1DExec mat_exec(mat);
549  mat_exec.mul(a);
550  return a;
551 }
552 
// Multiplies the ciphertext a by the full transformation mat and returns a.
// A fresh MatMulFullExec (constructed with its default `minimal` argument) is
// built on every call, so the encoded constants are not reused across calls;
// to reuse them, build the Exec object once and use the
// operator*=(Ctxt&, const MatMulExecBase&) overload instead.
553 inline Ctxt& operator*=(Ctxt& a, const MatMulFull& mat)
554 {
555  MatMulFullExec mat_exec(mat);
556  mat_exec.mul(a);
557  return a;
558 }
559 
// Multiplies the ciphertext a by the full block transformation mat and
// returns a. A fresh BlockMatMulFullExec (constructed with its default
// `minimal` argument) is built on every call, so the encoded constants are
// not reused across calls; to reuse them, build the Exec object once and use
// the operator*=(Ctxt&, const MatMulExecBase&) overload instead.
560 inline Ctxt& operator*=(Ctxt& a, const BlockMatMulFull& mat)
561 {
562  BlockMatMulFullExec mat_exec(mat);
563  mat_exec.mul(a);
564  return a;
565 }
566 
567 // For ctxt's, these functions do allow pre-computation
568 
// Multiplies the ciphertext a by an already-constructed Exec object and
// returns a. The call is dispatched through the virtual
// MatMulExecBase::mul, so any concrete Exec subclass can be passed and the
// same Exec object can be applied to many ciphertexts.
569 inline Ctxt& operator*=(Ctxt& a, const MatMulExecBase& mat)
570 {
571  mat.mul(a);
572  return a;
573 }
574 
575 // These are used mainly for performance evaluation.
576 
577 extern int fhe_test_force_bsgs;
578 // Controls whether or not we use BSGS multiplication.
579 // 1 to force on, -1 to force off, 0 for default behaviour.
580 
581 extern int fhe_test_force_hoist;
582 // Controls whether or not we use hoisting.
583 // -1 to force off, 0 for default behaviour.
584 
585 } // namespace helib
586 
587 #endif // ifndef HELIB_MATMUL_H
Definition: matmul.h:231
Definition: matmul.h:215
Definition: matmul.h:350
ConstMultiplierCache cache1
Definition: matmul.h:361
long D
Definition: matmul.h:355
ConstMultiplierCache cache
Definition: matmul.h:360
const EncryptedArray & ea
Definition: matmul.h:352
void upgrade() override
Definition: matmul.h:379
long dim
Definition: matmul.h:354
long strategy
Definition: matmul.h:358
long d
Definition: matmul.h:356
const EncryptedArray & getEA() const override
Definition: matmul.h:385
bool native
Definition: matmul.h:357
void mul(Ctxt &ctxt) const override
Definition: matmul.cpp:1697
Definition: matmul.h:204
virtual const EncryptedArray & getEA() const =0
BlockMatMul1DExec ExecType
Definition: matmul.h:209
virtual long getDim() const =0
virtual ~BlockMatMul1D()
Definition: matmul.h:206
Definition: matmul.h:61
Definition: matmul.h:433
std::vector< BlockMatMul1DExec > transforms
Definition: matmul.h:438
std::vector< long > dims
Definition: matmul.h:437
void mul(Ctxt &ctxt) const override
Definition: matmul.cpp:2597
bool minimal
Definition: matmul.h:436
const EncryptedArray & getEA() const override
Definition: matmul.h:463
void upgrade() override
Definition: matmul.h:457
const EncryptedArray & ea
Definition: matmul.h:435
Definition: matmul.h:51
virtual const EncryptedArray & getEA() const =0
BlockMatMulFullExec ExecType
Definition: matmul.h:55
virtual ~BlockMatMulFull()
Definition: matmul.h:53
Maintaining the HE scheme parameters.
Definition: Context.h:100
A Ctxt object holds a single ciphertext.
Definition: Ctxt.h:396
Definition: matmul.h:340
EncodedMatMul_CKKS(const MatMul1D_CKKS &mat)
Definition: matmul.h:342
A different derived class to be used for the approximate-numbers scheme.
Definition: EncryptedArray.h:880
Derived concrete implementation of EncryptedArrayBase.
Definition: EncryptedArray.h:403
A simple wrapper for a smart pointer to an EncryptedArrayBase. This is the interface that higher-leve...
Definition: EncryptedArray.h:1583
const Context & getContext() const
Definition: EncryptedArray.h:1668
Definition: matmul.h:122
virtual long getDim() const final
Definition: matmul.h:132
virtual std::complex< double > get(long i, long j) const =0
Definition: matmul.h:103
Definition: matmul.h:88
Definition: matmul.h:293
bool native
Definition: matmul.h:299
const EncryptedArray & ea
Definition: matmul.h:295
const EncryptedArray & getEA() const override
Definition: matmul.h:333
ConstMultiplierCache cache1
Definition: matmul.h:304
void upgrade() override
Definition: matmul.h:327
void mul(Ctxt &ctxt) const override
Definition: matmul.cpp:973
long g
Definition: matmul.h:301
long dim
Definition: matmul.h:297
long D
Definition: matmul.h:298
ConstMultiplierCache cache
Definition: matmul.h:303
bool minimal
Definition: matmul.h:300
Definition: matmul.h:77
virtual ~MatMul1D()
Definition: matmul.h:79
virtual const EncryptedArray & getEA() const =0
MatMul1DExec ExecType
Definition: matmul.h:82
virtual long getDim() const =0
Definition: matmul.h:170
MatMul_CKKS_Complex(const Context &context, get_fun_type _get_fun)
Definition: matmul.h:186
MatMul_CKKS_Complex(const EncryptedArray &_ea, get_fun_type _get_fun)
Definition: matmul.h:182
std::function< std::complex< double >long, long)> get_fun_type
Definition: matmul.h:172
virtual std::complex< double > get(long i, long j) const override
Definition: matmul.h:192
virtual const EncryptedArray & getEA() const override
Definition: matmul.h:190
Definition: matmul.h:141
virtual const EncryptedArray & getEA() const override
Definition: matmul.h:161
MatMul_CKKS(const EncryptedArray &_ea, get_fun_type _get_fun)
Definition: matmul.h:153
virtual std::complex< double > get(long i, long j) const override
Definition: matmul.h:163
MatMul_CKKS(const Context &context, get_fun_type _get_fun)
Definition: matmul.h:157
std::function< double(long, long)> get_fun_type
Definition: matmul.h:143
Definition: matmul.h:271
virtual const EncryptedArray & getEA() const =0
virtual void mul(Ctxt &ctxt) const =0
virtual ~MatMulExecBase()
Definition: matmul.h:273
virtual void upgrade()=0
Definition: matmul.h:35
virtual bool get(RX &out, long i, long j) const =0
Definition: matmul.h:393
const EncryptedArray & getEA() const override
Definition: matmul.h:422
const EncryptedArray & ea
Definition: matmul.h:395
void mul(Ctxt &ctxt) const override
Definition: matmul.cpp:2254
void upgrade() override
Definition: matmul.h:416
std::vector< MatMul1DExec > transforms
Definition: matmul.h:398
bool minimal
Definition: matmul.h:396
std::vector< long > dims
Definition: matmul.h:397
Definition: matmul.h:25
virtual const EncryptedArray & getEA() const =0
MatMulFullExec ExecType
Definition: matmul.h:29
virtual ~MatMulFull()
Definition: matmul.h:27
Definition: EncryptedArray.h:2167
PlaintextArray pa
Definition: EncryptedArray.h:2178
const EncryptedArray & ea
Definition: EncryptedArray.h:2177
Definition: apiAttributes.h:21
void mul(const EncryptedArray &ea, PlaintextArray &pa, const PlaintextArray &other)
Definition: EncryptedArray.cpp:1612
MatMul1D_derived< PA_cx > MatMul1D_CKKS
Definition: matmul.h:135
int fhe_test_force_hoist
Definition: matmul.cpp:24
int fhe_test_force_bsgs
Definition: matmul.cpp:23
void assertTrue(const T &value, const std::string &message)
Definition: assertions.h:61
PtxtArray & operator*=(PtxtArray &a, const PtxtArray &b)
Definition: EncryptedArray.h:2468
void traceMap(Ctxt &ctxt)
Definition: matmul.cpp:2865
Definition: io.h:50
Definition: matmul.h:259
void upgrade(const Context &context)
Definition: matmul.cpp:410
std::vector< std::shared_ptr< ConstMultiplier > > multiplier
Definition: matmul.h:260
Definition: matmul.cpp:318