본문 바로가기

IT 이야기

nCr % p 소스코드

nCr % p 소스코드입니다.

p가 소수일 때, 작동하며 전역으로 선언해둔 p 값을 변경하시면 됩니다.

그냥 직감으로 짠거라 최적화가 덜 되거나, 돌아가지 않는 조건이 있을 수 있습니다.

입력은 n과 r값을 받으며, n의 최댓값은 10000입니다.

#include <stdio.h>
#define ret return
#define uLL unsigned long long

uLL p = 100000007;

uLL mem[10001];

uLL pm(uLL a, uLL b, uLL n)
{
   uLL am = a % n;
   uLL bm = b % n;

   if (am == 0) ret bm;
   if (bm == 0) ret am;

   if (am + bm <= am) ret(am - (n - bm)) % n;


   ret(am + bm) % n;
}

uLL mm(uLL a, uLL b, uLL n)
{
   uLL am = a % n;
   uLL bm = b % n;

   if (am == 0 || bm == 0) ret 0;
   if (am == 1) ret bm;
   if (bm == 1) ret am;


   uLL asquared = mm(am, bm >> 1, n);
   if ((bm & 1) == 0) ret pm(asquared, asquared, n);

   ret pm(am, pm(asquared, asquared, n), n);
}


uLL facmod(uLL a, uLL b, uLL n, uLL cnt) // a! mod n, b = 1, cnt = 1
{
   mem[cnt] = b;
   if (a == cnt) ret b;
   cnt++;
//   ret facmod(a, cnt*b%n, n, cnt);
   ret facmod(a, mm(cnt, b, n), n, cnt);
}

uLL fast_mod(uLL base, uLL exp)
{
   uLL r = 1, sq = exp;
   while (sq)
   {
   //   printf("%lld\n", sq);
      if (sq & 1) r = mm(r, base, p);


      sq >>= 1;
      base = mm(base, base, p);
   }
   ret r;
}






uLL ncr(uLL n, uLL r)
{
   uLL big_a, big_b, exp_b;
   big_a = facmod(n, 1, p, 1);
//   big_b = mem[r] * mem[n - r] % p;
   big_b = mm(mem[r], mem[n - r], p);
   exp_b = fast_mod(big_b, p - 2);
//   printf("%lld %lld %lld\n", big_a, big_b, exp_b);
   ret mm(big_a, exp_b, p);
}



int main()
{
   uLL n, r;
   scanf("%lld %lld", &n, &r);
   uLL tmp;
   if (r == 0 || n == r) tmp = 1;
   else tmp = ncr(n, r);
   printf("%lld", tmp);
   ret 0;
}

참고한 곳:

https://helloworldpark.github.io/programming/2017/03/14/Modulo_No_Overflow.html

https://cru6548.tistory.com/23