auto import from //depot/cupcake/@135843
[android/platform/external/neven.git] / Embedded / common / src / b_TensorEm / Int32Mat.c
1 /*
2  * Copyright (C) 2008 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /* ---- includes ----------------------------------------------------------- */
18
19 #include "b_TensorEm/Int32Mat.h"
20 #include "b_TensorEm/Functions.h"
21 #include "b_BasicEm/Math.h"
22 #include "b_BasicEm/Functions.h"
23 #include "b_BasicEm/Memory.h"
24
25 /* ------------------------------------------------------------------------- */
26
27 /* ========================================================================= */
28 /*                                                                           */
29 /* ---- \ghd{ auxiliary functions } ---------------------------------------- */
30 /*                                                                           */
31 /* ========================================================================= */
32
33 /* ------------------------------------------------------------------------- */
34
35 void bts_Int32Mat_reduceToNBits( int32* ptrA, uint32 sizeA, int32* bbpPtrA, uint32 nBitsA )
36 {
37         int32 shiftL;
38
39         /* find max element */
40         int32 maxL = 0;
41         int32* ptrL = ptrA;
42         int32 iL = sizeA;
43         while( iL-- )
44         {
45                 int32 xL = *ptrL++;
46                 if( xL < 0 ) xL = -xL;
47                 if( xL > maxL ) maxL = xL;
48         }
49
50         /* determine shift */
51         shiftL = bts_absIntLog2( maxL ) + 1 - nBitsA;
52
53         if( shiftL > 0 )
54         {
55                 ptrL = ptrA;
56                 iL = sizeA;
57                 while( iL-- )
58                 {
59                         *ptrL = ( ( *ptrL >> ( shiftL - 1 ) ) + 1 ) >> 1;
60                         ptrL++;
61                 }
62
63                 *bbpPtrA -= shiftL;
64         }
65 }
66
67 /* ------------------------------------------------------------------------- */
68
69 /* ========================================================================= */
70 /*                                                                           */
71 /* ---- \ghd{ constructor / destructor } ----------------------------------- */
72 /*                                                                           */
73 /* ========================================================================= */
74
75 /* ------------------------------------------------------------------------- */
76
77 void bts_Int32Mat_init( struct bbs_Context* cpA,
78                                             struct bts_Int32Mat* ptrA )
79 {
80         ptrA->widthE = 0;
81         bbs_Int32Arr_init( cpA, &ptrA->arrE );
82 }
83
84 /* ------------------------------------------------------------------------- */
85
86 void bts_Int32Mat_exit( struct bbs_Context* cpA,
87                                             struct bts_Int32Mat* ptrA )
88 {
89         ptrA->widthE = 0;
90         bbs_Int32Arr_exit( cpA, &ptrA->arrE );
91 }
92 /* ------------------------------------------------------------------------- */
93
94 /* ========================================================================= */
95 /*                                                                           */
96 /* ---- \ghd{ operators } -------------------------------------------------- */
97 /*                                                                           */
98 /* ========================================================================= */
99
100 /* ------------------------------------------------------------------------- */
101
102 /* ========================================================================= */
103 /*                                                                           */
104 /* ---- \ghd{ query functions } -------------------------------------------- */
105 /*                                                                           */
106 /* ========================================================================= */
107
108 /* ------------------------------------------------------------------------- */
109
110 /* ========================================================================= */
111 /*                                                                           */
112 /* ---- \ghd{ modify functions } ------------------------------------------- */
113 /*                                                                           */
114 /* ========================================================================= */
115
116 /* ------------------------------------------------------------------------- */
117         
118 void bts_Int32Mat_create( struct bbs_Context* cpA,
119                                                   struct bts_Int32Mat* ptrA, 
120                                                   int32 widthA,
121                                           struct bbs_MemSeg* mspA )
122 {
123         if( bbs_Context_error( cpA ) ) return;
124         bbs_Int32Arr_create( cpA, &ptrA->arrE, widthA * widthA, mspA );
125         ptrA->widthE = widthA;
126 }
127
128 /* ------------------------------------------------------------------------- */
129         
130 void bts_Int32Mat_copy( struct bbs_Context* cpA,
131                                             struct bts_Int32Mat* ptrA, 
132                                                 const struct bts_Int32Mat* srcPtrA )
133 {
134         if( ptrA->widthE != srcPtrA->widthE )
135         {
136                 bbs_ERROR0( "void bts_Int32Mat_copy( struct bts_Int32Mat* ptrA, struct bts_Int32Mat* srcPtrA ):\n"
137                                "size mismatch" );
138                 return;
139         }
140
141         bbs_Int32Arr_copy( cpA, &ptrA->arrE, &srcPtrA->arrE );
142 }
143
144 /* ------------------------------------------------------------------------- */
145         
146 /* ========================================================================= */
147 /*                                                                           */
148 /* ---- \ghd{ I/O } -------------------------------------------------------- */
149 /*                                                                           */
150 /* ========================================================================= */
151
152 /* ------------------------------------------------------------------------- */
153         
154 uint32 bts_Int32Mat_memSize( struct bbs_Context* cpA,
155                                                          const struct bts_Int32Mat *ptrA )
156 {
157         return  bbs_SIZEOF16( uint32 )
158                   + bbs_SIZEOF16( uint32 ) /* version */
159                   + bbs_SIZEOF16( ptrA->widthE ) 
160                   + bbs_Int32Arr_memSize( cpA, &ptrA->arrE );
161 }
162
163 /* ------------------------------------------------------------------------- */
164         
165 uint32 bts_Int32Mat_memWrite( struct bbs_Context* cpA,
166                                                           const struct bts_Int32Mat* ptrA, 
167                                                           uint16* memPtrA )
168 {
169         uint32 memSizeL = bts_Int32Mat_memSize( cpA, ptrA );
170         memPtrA += bbs_memWrite32( &memSizeL, memPtrA );
171         memPtrA += bbs_memWriteUInt32( bts_INT32MAT_VERSION, memPtrA );
172         memPtrA += bbs_memWrite32( &ptrA->widthE, memPtrA );
173         memPtrA += bbs_Int32Arr_memWrite( cpA, &ptrA->arrE, memPtrA );
174         return memSizeL;
175 }
176
177 /* ------------------------------------------------------------------------- */
178         
179 uint32 bts_Int32Mat_memRead( struct bbs_Context* cpA,
180                                                          struct bts_Int32Mat* ptrA, 
181                                                          const uint16* memPtrA,
182                                              struct bbs_MemSeg* mspA )
183 {
184         uint32 memSizeL, versionL;
185         if( bbs_Context_error( cpA ) ) return 0;
186         memPtrA += bbs_memRead32( &memSizeL, memPtrA );
187         memPtrA += bbs_memReadVersion32( cpA, &versionL, bts_INT32MAT_VERSION, memPtrA );
188         memPtrA += bbs_memRead32( &ptrA->widthE, memPtrA );
189         memPtrA += bbs_Int32Arr_memRead( cpA, &ptrA->arrE, memPtrA, mspA );
190
191         if( memSizeL != bts_Int32Mat_memSize( cpA, ptrA ) )
192         {
193                 bbs_ERR0( bbs_ERR_CORRUPT_DATA, "uint32 bts_Int32Mat_memRead( const struct bts_Int32Mat* ptrA, const void* memPtrA ):\n"
194                   "size mismatch" ); 
195         }
196         return memSizeL;
197 }
198
199 /* ------------------------------------------------------------------------- */
200         
201 /* ========================================================================= */
202 /*                                                                           */
203 /* ---- \ghd{ exec functions } --------------------------------------------- */
204 /*                                                                           */
205 /* ========================================================================= */
206
207 /* ------------------------------------------------------------------------- */
208
209 flag bts_Int32Mat_solve( struct bbs_Context* cpA,
210                                                  const int32* matA,
211                                                  int32 matWidthA,
212                                                  const int32* inVecA,
213                                                  int32* outVecA,
214                                                  int32 bbpA,
215                                                  int32* tmpMatA,
216                                                  int32* tmpVecA )
217 {
218         bbs_memcpy32( tmpMatA, matA, ( matWidthA * matWidthA ) * bbs_SIZEOF32( int32 ) );
219
220         return bts_Int32Mat_solve2( cpA, 
221                                         tmpMatA,
222                                                                 matWidthA,
223                                                                 inVecA,
224                                                                 outVecA,
225                                                                 bbpA,
226                                                                 tmpVecA );
227 }
228
229 /* ------------------------------------------------------------------------- */
230
231 flag bts_Int32Mat_solve2( struct bbs_Context* cpA,
232                                                   int32* matA,
233                                                   int32 matWidthA,
234                                                   const int32* inVecA,
235                                                   int32* outVecA,
236                                                   int32 bbpA,
237                                                   int32* tmpVecA )
238 {
239         int32 sizeL = matWidthA;
240         int32 bbpL = bbpA;
241         int32 iL, jL, kL;
242         int32 iPivL;
243         int32 jPivL;
244
245         int32* vecL      = outVecA;
246         int32* matL      = matA;
247         int32* checkArrL = tmpVecA;
248
249         for( iL = 0; iL < sizeL; iL++ )
250         {
251                 checkArrL[ iL ] = 0;
252         }
253         
254         bbs_memcpy32( outVecA, inVecA, sizeL * bbs_SIZEOF32( int32 ) );
255
256         iPivL = 0;
257
258         for( kL = 0; kL < sizeL; kL++ )
259         {
260                 /* find pivot */
261                 int32 maxAbsL = 0;
262                 int32* pivRowL;
263
264                 int32 bbp_pivRowL, bbp_vecL, shiftL;
265
266                 jPivL = -1;
267                 for( iL = 0; iL < sizeL; iL++ )
268                 {
269                         if( checkArrL[ iL ] != 1 )
270                         {
271                                 int32* rowL = matL + ( iL * sizeL );
272                                 for( jL = 0; jL < sizeL; jL++ )
273                                 {
274                                         if( checkArrL[ jL ] == 0 )
275                                         {
276                                                 int32 absElemL = rowL[ jL ];
277                                                 if( absElemL < 0 ) absElemL = -absElemL;
278                                                 if( maxAbsL < absElemL )
279                                                 {
280                                                         maxAbsL = absElemL;
281                                                         iPivL = iL;
282                                                         jPivL = jL;
283                                                 }
284                                         } 
285                                         else if( checkArrL[ jL ] > 1 )
286                                         {
287                                                 return FALSE;
288                                         }
289                                 }
290                         }
291                 }
292
293                 /* successfull ? */
294                 if( jPivL < 0 )
295                 {
296                         return FALSE;
297                 }
298
299                 checkArrL[ jPivL ]++; 
300
301                 /* exchange rows to put pivot on diagonal, if neccessary */
302                 if( iPivL != jPivL )
303                 {
304                         int32* row1PtrL = matL + ( iPivL * sizeL );
305                         int32* row2PtrL = matL + ( jPivL * sizeL );
306                         for( jL = 0; jL < sizeL; jL++ )
307                         {
308                                 int32 tmpL = *row1PtrL;
309                                 *row1PtrL++ = *row2PtrL;
310                                 *row2PtrL++ = tmpL;
311                         }
312
313                         {
314                                 int32 tmpL = vecL[ jPivL ];
315                                 vecL[ jPivL ] = vecL[ iPivL ];
316                                 vecL[ iPivL ] = tmpL;
317                         }
318                 }
319                 /* now index jPivL specifies pivot row and maximum element */
320
321
322                 /**     Overflow protection: only if the highest bit of the largest matrix element is set,
323                  *      we need to shift the whole matrix and the right side vector 1 bit to the right,
324                  *      to make sure there can be no overflow when the pivot row gets subtracted from the
325                  *      other rows.
326                  *      Getting that close to overflow is a rare event, so this shift will happen only 
327                  *      occasionally, or not at all.
328                  */
329                 if( maxAbsL & 1073741824 )  /*( 1 << 30 )*/
330                 {
331                         /* right shift matrix by 1 */
332                         int32 iL = sizeL * sizeL;
333                         int32* ptrL = matL;
334                         while( iL-- )
335                         {
336                                 *ptrL = ( *ptrL + 1 ) >> 1;
337                                 ptrL++;
338                         }
339
340                         /* right shift right side vector by 1 */
341                         iL = sizeL;
342                         ptrL = vecL;
343                         while( iL-- )
344                         {
345                                 *ptrL = ( *ptrL + 1 ) >> 1;
346                                 ptrL++;
347                         }
348
349                         /* decrement bbpL */
350                         bbpL--;
351                 }
352
353
354                 /* reduce elements of pivot row to 15 bit */
355                 pivRowL = matL + jPivL * sizeL;
356                 bbp_pivRowL = bbpL;
357                 bts_Int32Mat_reduceToNBits( pivRowL, sizeL, &bbp_pivRowL, 15 );
358
359                 /* scale pivot row such that maximum equals 1 */
360                 {
361                         int32 maxL = pivRowL[ jPivL ];
362                         int32 bbp_maxL = bbp_pivRowL;
363                         int32 factorL = 1073741824 / maxL; /*( 1 << 30 )*/
364
365                         for( jL = 0; jL < sizeL; jL++ )
366                         {
367                                 pivRowL[ jL ] = ( pivRowL[ jL ] * factorL + ( 1 << 14 ) ) >> 15;
368                         }
369                         bbp_pivRowL = 15;
370
371                         /* set to 1 to avoid computational errors */
372                         pivRowL[ jPivL ] = ( int32 )1 << bbp_pivRowL; 
373
374                         shiftL = 30 - bts_absIntLog2( vecL[ jPivL ] );
375
376                         vecL[ jPivL ] = ( vecL[ jPivL ] << shiftL ) / maxL;
377                         bbp_vecL = bbpL + shiftL - bbp_maxL;
378
379                         bbs_int32ReduceToNBits( &( vecL[ jPivL ] ), &bbp_vecL, 15 );
380                 }
381
382                 /* subtract pivot row from all other rows */
383                 for( iL = 0; iL < sizeL; iL++ )
384                 {
385                         if( iL != jPivL )
386                         {
387                                 int32* rowPtrL = matL + iL * sizeL;
388
389                                 int32 tmpL = *( rowPtrL + jPivL );
390                                 int32 bbp_tmpL = bbpL;
391                                 bbs_int32ReduceToNBits( &tmpL, &bbp_tmpL, 15 );
392
393                                 shiftL = bbp_tmpL + bbp_pivRowL - bbpL;
394                                 if( shiftL > 0 )
395                                 {
396                                         for( jL = 0; jL < sizeL; jL++ )
397                                         {
398                                                 *rowPtrL++ -= ( ( ( tmpL * pivRowL[ jL ] ) >> ( shiftL - 1 ) ) + 1 ) >> 1;
399                                         }
400                                 }
401                                 else
402                                 {
403                                         for( jL = 0; jL < sizeL; jL++ )
404                                         {
405                                                 *rowPtrL++ -= ( tmpL * pivRowL[ jL ] ) << -shiftL;
406                                         }
407                                 }
408
409                                 shiftL = bbp_tmpL + bbp_vecL - bbpL;
410                                 if( shiftL > 0 )
411                                 {
412                                         vecL[ iL ] -= ( ( ( tmpL * vecL[ jPivL ] ) >> ( shiftL - 1 ) ) + 1 ) >> 1;
413                                 }
414                                 else
415                                 {
416                                         vecL[ iL ] -= ( tmpL * vecL[ jPivL ] ) << -shiftL;
417                                 }
418                         }
419                 }
420
421                 /* change bbp of pivot row back to bbpL */
422                 shiftL = bbpL - bbp_pivRowL;
423                 if( shiftL >= 0 )
424                 {
425                         for( jL = 0; jL < sizeL; jL++ )
426                         {
427                                 pivRowL[ jL ] <<= shiftL;
428                         }
429                 }
430                 else
431                 {
432                         shiftL = -shiftL;
433                         for( jL = 0; jL < sizeL; jL++ )
434                         {
435                                 pivRowL[ jL ] = ( ( pivRowL[ jL ] >> ( shiftL - 1 ) ) + 1 ) >> 1;
436                         }
437                 }
438
439                 shiftL = bbpL - bbp_vecL;
440                 if( shiftL >= 0 )
441                 {
442                         vecL[ jPivL ] <<= shiftL;
443                 }
444                 else
445                 {
446                         shiftL = -shiftL;
447                         vecL[ jPivL ] = ( ( vecL[ jPivL ] >> ( shiftL - 1 ) ) + 1 ) >> 1;
448                 }
449 /*
450 if( sizeL <= 5 ) bts_Int32Mat_print( matL, vecL, sizeL, bbpL );
451 */
452         }       /* of kL */
453
454         /* in case bbpL has been decreased by the overflow protection, change it back now */
455         if( bbpA > bbpL )
456         {
457                 /* find largest element of solution vector */
458                 int32 maxL = 0;
459                 int32 iL, shiftL;
460                 for( iL = 0; iL < sizeL; iL++ )
461                 {
462                         int32 xL = vecL[ iL ];
463                         if( xL < 0 ) xL = -xL;
464                         if( xL > maxL ) maxL = xL;
465                 }
466                 
467                 /* check whether we can left shift without overflow */
468                 shiftL = 30 - bts_absIntLog2( maxL );
469                 if( shiftL < ( bbpA - bbpL ) )
470                 {
471                         /* 
472                             bbs_WARNING1( "flag bts_Int32Mat_solve2( ... ): getting overflow when trying to "
473                                 "compute solution vector with bbp = %d. Choose smaller bbp.\n", bbpA );
474                         */
475
476                         return FALSE;
477                 }       
478
479                 /* shift left */
480                 shiftL = bbpA - bbpL;
481                 for( iL = 0; iL < sizeL; iL++ ) vecL[ iL ] <<= shiftL;
482         }
483
484         return TRUE;
485 }
486
487 /* ------------------------------------------------------------------------- */
488
489 /* ========================================================================= */
490