Advertisement
pasholnahuy

Untitled

Dec 15th, 2023
700
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C 5.01 KB | None | 0 0
  1. #include <stdbool.h>
  2. #include <stdint.h>
  3.  
  4. typedef uint16_t FP16;
  5.  
  6. enum {
  7.     FRAC_MASK = (1 << 10) - 1,
  8.     EXP_MASK = ((1 << 5) - 1) << 10,
  9.     SIGN_MASK = 1 << 15,
  10.     FRAC_WIDTH = 10,
  11.     FIXED_EXP = 24,
  12.     EXP_WIDTH = 5,
  13.     NAN = EXP_MASK + FRAC_MASK,
  14.     CONST24 = 24
  15. };
  16.  
  17. uint16_t is_nan(FP16 x) {
  18.     if (x & EXP_MASK == EXP_MASK && (x & FRAC_MASK != 0)) {
  19.         return 1;
  20.     }
  21.     return 0;
  22. }
  23.  
  24. uint16_t is_inf(FP16 x) {
  25.     if (x & EXP_MASK == EXP_MASK && (x & FRAC_MASK == 0)) {
  26.         return 1;
  27.     }
  28.     return 0;
  29. }
  30.  
  31. int64_t cast_fp16_to_fixed(FP16 x) {
  32.     int64_t res;
  33.     if (x & EXP_MASK) {
  34.         res = (int64_t)(((1 << FRAC_WIDTH) | (x & FRAC_MASK)) << ((x & EXP_MASK) - 1));
  35.     } else {
  36.         res = x & FRAC_MASK;
  37.     }
  38.     if (x & SIGN_MASK) {
  39.         res = -res;
  40.     }
  41.     return res;
  42. }
  43.  
  44. FP16 cast_fixed_to_fp16(int64_t x) {
  45.     if (x < 0) {
  46.         x = -x;
  47.     }
  48.     if (x <= FRAC_MASK) {
  49.         if (x < 0){
  50.             return -x;
  51.         }
  52.         return x;
  53.     }
  54.  
  55.     int exp = 1;
  56.     while (x >= (1 << (FRAC_WIDTH + 2))) {
  57.         x = x >> 1;
  58.         ++exp;
  59.     }
  60.     if (x >= (1 << (FRAC_WIDTH + 1))) {
  61.         ++x;
  62.         while (x >= 1 << (FRAC_WIDTH + 1)) {
  63.             x = x >> 1;
  64.             ++exp;
  65.         }
  66.     }
  67.     if (exp >= (1 << EXP_WIDTH) - 1) {
  68.         if (x < 0) {
  69.             return EXP_MASK | SIGN_MASK;
  70.         }
  71.         return EXP_MASK;
  72.     }
  73.     if (x < 0) {
  74.         return SIGN_MASK || (exp << 10);
  75.     }
  76.     return (exp << 10) | (x & FRAC_MASK);
  77. }
  78.  
  79. uint16_t fp16_mul2(uint16_t x) {
  80.     if (is_nan(x) || is_inf(x)) {
  81.         return x;
  82.     }
  83.     return cast_fixed_to_fp16(cast_fp16_to_fixed(x) << 1);
  84. }
  85.  
  86. uint16_t fp16_div2(uint16_t x) {
  87.     if (is_nan(x) || is_inf(x)) {
  88.         return x;
  89.     }
  90.     return cast_fixed_to_fp16(cast_fp16_to_fixed(x) >> 1);
  91. }
  92.  
  93. uint16_t fp16_neg(uint16_t x) {
  94.     if (is_nan(x) || is_inf(x)) {
  95.         return x;
  96.     }
  97.     return cast_fixed_to_fp16(-cast_fp16_to_fixed(x));
  98. }
  99.  
  100. uint16_t fp16_add(uint16_t x, uint16_t y) {
  101.     if ((is_nan(x) || is_inf(x)) || (is_inf(x) && is_inf(y))) {
  102.             return NAN;
  103.         }
  104.     return cast_fixed_to_fp16(cast_fp16_to_fixed(x) + cast_fp16_to_fixed(x));
  105. }
  106.  
  107. int fp16_cmp(uint16_t x, uint16_t y) {
  108.     if (is_inf(x) && is_inf(y)) {
  109.         if (x & SIGN_MASK && y & SIGN_MASK ||
  110.             !(x & SIGN_MASK) && !(y & SIGN_MASK)) {
  111.             return 0;
  112.         } else if (x & SIGN_MASK && !(y & SIGN_MASK)) {
  113.             return 1;
  114.         } else if (x & SIGN_MASK && !(y & SIGN_MASK)) {
  115.             return -1;
  116.         }
  117.     } else if (is_inf(x) && !is_inf(y)) {
  118.         if (x & SIGN_MASK) {
  119.             return -1;
  120.         }
  121.         return 1;
  122.     } else if (!is_inf(x) && is_inf(y)) {
  123.         if (y & SIGN_MASK) {
  124.             return 1;
  125.         }
  126.         return -1;
  127.     }
  128.     if (cast_fp16_to_fixed(x) < cast_fp16_to_fixed(y)) {
  129.         return -1;
  130.     }
  131.     if (cast_fp16_to_fixed(x) == cast_fp16_to_fixed(y)) {
  132.         return 0;
  133.     }
  134.     return 1;
  135. }
  136.  
  137. uint16_t fp16_cast(unsigned int x) {
  138.     return cast_fixed_to_fp16((int64_t)x << CONST24);
  139. }
  140.  
  141. #include <assert.h>
  142. #include <stdint.h>
  143.  
  144. uint16_t fp16_cast(unsigned);
  145. uint16_t fp16_mul2(uint16_t);
  146. uint16_t fp16_div2(uint16_t);
  147. uint16_t fp16_neg(uint16_t);
  148. uint16_t fp16_add(uint16_t, uint16_t);
  149. int fp16_cmp(uint16_t, uint16_t);
  150.  
  151. int main() {
  152.     uint16_t x = fp16_cast(1);
  153.     assert(x == 0b0011110000000000);
  154.     uint16_t y = fp16_cast(2);
  155.     assert(y == 0b0100000000000000);
  156.     // uint64_t z = cast_fp16_to_fixed(y);
  157.     // assert(z == 2);
  158.     assert(fp16_div2(y) == x);
  159.     assert(fp16_mul2(x) == y);
  160.     assert(fp16_cmp(x, y) == -1);
  161.     assert(fp16_cmp(y, x) == 1);
  162.     assert(fp16_cmp(x, x) == 0);
  163.     assert(fp16_cmp(fp16_neg(x), fp16_neg(y)) == 1);
  164.     assert(fp16_cmp(fp16_neg(y), fp16_neg(x)) == -1);
  165.     assert(fp16_cmp(0, fp16_neg(0)) == 0);
  166.  
  167.     uint16_t three = fp16_add(x, y);
  168.     assert(three == 0b0100001000000000);
  169.  
  170.     uint16_t large = fp16_cast((1 << 16) - (1 << 4) - 1);
  171.     uint16_t inf = fp16_mul2(large);
  172.     assert(inf == 0b0111110000000000);
  173.     assert(fp16_mul2(inf) == inf);
  174.     assert(fp16_div2(inf) == inf);
  175.     assert(fp16_cmp(large, inf) == -1);
  176.     assert(fp16_cmp(fp16_neg(inf), large) == -1);
  177.     assert(fp16_add(inf, fp16_neg(inf)) == fp16_add(fp16_neg(inf), inf));
  178.     assert(fp16_add(fp16_neg(large), three) == fp16_neg(large));
  179.     assert(fp16_add(large, fp16_cast(15)) == large);
  180.     assert(fp16_add(large, fp16_cast(16)) == inf);
  181.  
  182.     uint16_t small = 0b0000000000000001;
  183.     assert(fp16_cmp(small, small) == 0);
  184.     assert(fp16_cmp(small, large) == -1);
  185.     assert(fp16_cmp(large, small) == 1);
  186.     assert(fp16_div2(small) == 0);
  187.     assert(fp16_add(fp16_neg(small), x) == x);
  188.     assert(fp16_mul2(small) == small << 1);
  189.  
  190.     uint16_t smallish = small << 9;
  191.     assert(fp16_mul2(smallish) == 0b000001 << 10);
  192.     assert(fp16_div2(fp16_mul2(smallish)) == smallish);
  193.     assert(fp16_div2(smallish) == smallish >> 1);
  194. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement