pasholnahuy

Untitled

Dec 15th, 2023
127
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C 5.02 KB | None | 0 0
  1. #include <stdbool.h>
  2. #include <stdint.h>
  3.  
  4. typedef uint16_t FP16;
  5.  
  6. enum {
  7.     FRAC_MASK = (1 << 10) - 1,
  8.     EXP_MASK = ((1 << 5) - 1) << 10,
  9.     SIGN_MASK = 1 << 15,
  10.     FRAC_WIDTH = 10,
  11.     FIXED_EXP = 24,
  12.     EXP_WIDTH = 5,
  13.     NAN = EXP_MASK + FRAC_MASK,
  14.     CONST24 = 24
  15. };
  16.  
  17. uint16_t is_nan(FP16 x) {
  18.     if (x & EXP_MASK == EXP_MASK && (x & FRAC_MASK != 0)) {
  19.         return 1;
  20.     }
  21.     return 0;
  22. }
  23.  
  24. uint16_t is_inf(FP16 x) {
  25.     if (x & EXP_MASK == EXP_MASK && (x & FRAC_MASK == 0)) {
  26.         return 1;
  27.     }
  28.     return 0;
  29. }
  30.  
  31. int64_t cast_fp16_to_fixed(FP16 x) {
  32.     int64_t res;
  33.     if (x & EXP_MASK) {
  34.         res = (int64_t)((1 << FRAC_WIDTH) | (x & FRAC_MASK)) << ((x & EXP_MASK) - 1);
  35.     } else {
  36.         res = x & FRAC_MASK;
  37.     }
  38.     if (x & SIGN_MASK) {
  39.         res = -res;
  40.     }
  41.     return res;
  42. }
  43.  
  44. FP16 cast_fixed_to_fp16(int64_t x) {
  45.     int u = 0;
  46.     if (x < 0) {
  47.         x = -x;
  48.         u = 1;
  49.     }
  50.     if (x <= FRAC_MASK) {
  51.         if (u){
  52.             return -x;
  53.         }
  54.         return x;
  55.     }
  56.  
  57.     int exp = 1;
  58.     while (x >= (1 << (FRAC_WIDTH + 2))) {
  59.         x = x >> 1;
  60.         ++exp;
  61.     }
  62.     if (x >= (1 << (FRAC_WIDTH + 1))) {
  63.         ++x;
  64.         while (x >= 1 << (FRAC_WIDTH + 1)) {
  65.             x = x >> 1;
  66.             ++exp;
  67.         }
  68.     }
  69.     if (exp >= (1 << EXP_WIDTH) - 1) {
  70.         if (u) {
  71.             return EXP_MASK | SIGN_MASK;
  72.         }
  73.         return EXP_MASK;
  74.     }
  75.     if (u) {
  76.         return SIGN_MASK | (exp << 10);
  77.     }
  78.     return (exp << 10) | (x & FRAC_MASK);
  79. }
  80.  
  81. uint16_t fp16_mul2(uint16_t x) {
  82.     if (is_nan(x) || is_inf(x)) {
  83.         return x;
  84.     }
  85.     return cast_fixed_to_fp16(cast_fp16_to_fixed(x) << 1);
  86. }
  87.  
  88. uint16_t fp16_div2(uint16_t x) {
  89.     if (is_nan(x) || is_inf(x)) {
  90.         return x;
  91.     }
  92.     return cast_fixed_to_fp16(cast_fp16_to_fixed(x) >> 1);
  93. }
  94.  
  95. uint16_t fp16_neg(uint16_t x) {
  96.     if (is_nan(x) || is_inf(x)) {
  97.         return x;
  98.     }
  99.     return cast_fixed_to_fp16(-cast_fp16_to_fixed(x));
  100. }
  101.  
  102. uint16_t fp16_add(uint16_t x, uint16_t y) {
  103.     if ((is_nan(x) || is_inf(x)) || (is_inf(x) && is_inf(y))) {
  104.             return NAN;
  105.         }
  106.     return cast_fixed_to_fp16(cast_fp16_to_fixed(x) + cast_fp16_to_fixed(x));
  107. }
  108.  
  109. int fp16_cmp(uint16_t x, uint16_t y) {
  110.     if (is_inf(x) && is_inf(y)) {
  111.         if (x & SIGN_MASK && y & SIGN_MASK ||
  112.             !(x & SIGN_MASK) && !(y & SIGN_MASK)) {
  113.             return 0;
  114.         } else if (x & SIGN_MASK && !(y & SIGN_MASK)) {
  115.             return 1;
  116.         } else if (x & SIGN_MASK && !(y & SIGN_MASK)) {
  117.             return -1;
  118.         }
  119.     } else if (is_inf(x) && !is_inf(y)) {
  120.         if (x & SIGN_MASK) {
  121.             return -1;
  122.         }
  123.         return 1;
  124.     } else if (!is_inf(x) && is_inf(y)) {
  125.         if (y & SIGN_MASK) {
  126.             return 1;
  127.         }
  128.         return -1;
  129.     }
  130.     if (cast_fp16_to_fixed(x) < cast_fp16_to_fixed(y)) {
  131.         return -1;
  132.     }
  133.     if (cast_fp16_to_fixed(x) == cast_fp16_to_fixed(y)) {
  134.         return 0;
  135.     }
  136.     return 1;
  137. }
  138.  
  139. uint16_t fp16_cast(unsigned int x) {
  140.     return cast_fixed_to_fp16((int64_t)x << CONST24);
  141. }
  142.  
  143. #include <assert.h>
  144. #include <stdint.h>
  145.  
  146. uint16_t fp16_cast(unsigned);
  147. uint16_t fp16_mul2(uint16_t);
  148. uint16_t fp16_div2(uint16_t);
  149. uint16_t fp16_neg(uint16_t);
  150. uint16_t fp16_add(uint16_t, uint16_t);
  151. int fp16_cmp(uint16_t, uint16_t);
  152.  
  153. int main() {
  154.     uint16_t x = fp16_cast(1);
  155.     assert(x == 0b0011110000000000);
  156.     uint16_t y = fp16_cast(2);
  157.     assert(y == 0b0100000000000000);
  158.     // uint64_t z = cast_fp16_to_fixed(y);
  159.     // assert(z == 2);
  160.     assert(fp16_div2(y) == x);
  161.     assert(fp16_mul2(x) == y);
  162.     assert(fp16_cmp(x, y) == -1);
  163.     assert(fp16_cmp(y, x) == 1);
  164.     assert(fp16_cmp(x, x) == 0);
  165.     assert(fp16_cmp(fp16_neg(x), fp16_neg(y)) == 1);
  166.     assert(fp16_cmp(fp16_neg(y), fp16_neg(x)) == -1);
  167.     assert(fp16_cmp(0, fp16_neg(0)) == 0);
  168.  
  169.     uint16_t three = fp16_add(x, y);
  170.     assert(three == 0b0100001000000000);
  171.  
  172.     uint16_t large = fp16_cast((1 << 16) - (1 << 4) - 1);
  173.     uint16_t inf = fp16_mul2(large);
  174.     assert(inf == 0b0111110000000000);
  175.     assert(fp16_mul2(inf) == inf);
  176.     assert(fp16_div2(inf) == inf);
  177.     assert(fp16_cmp(large, inf) == -1);
  178.     assert(fp16_cmp(fp16_neg(inf), large) == -1);
  179.     assert(fp16_add(inf, fp16_neg(inf)) == fp16_add(fp16_neg(inf), inf));
  180.     assert(fp16_add(fp16_neg(large), three) == fp16_neg(large));
  181.     assert(fp16_add(large, fp16_cast(15)) == large);
  182.     assert(fp16_add(large, fp16_cast(16)) == inf);
  183.  
  184.     uint16_t small = 0b0000000000000001;
  185.     assert(fp16_cmp(small, small) == 0);
  186.     assert(fp16_cmp(small, large) == -1);
  187.     assert(fp16_cmp(large, small) == 1);
  188.     assert(fp16_div2(small) == 0);
  189.     assert(fp16_add(fp16_neg(small), x) == x);
  190.     assert(fp16_mul2(small) == small << 1);
  191.  
  192.     uint16_t smallish = small << 9;
  193.     assert(fp16_mul2(smallish) == 0b000001 << 10);
  194.     assert(fp16_div2(fp16_mul2(smallish)) == smallish);
  195.     assert(fp16_div2(smallish) == smallish >> 1);
  196. }
Add Comment
Please, Sign In to add comment