동일한 디바이저일 때 빠른 AVX512 모듈로
저는 잠재적 요인 소수(n!+-1 형식의 수)에 대한 나눗셈을 찾으려고 노력했고, 최근에 스카이레이크-X 워크스테이션을 구입했기 때문에 AVX512 지침을 사용하여 속도를 높일 수 있을 것이라고 생각했습니다.
알고리즘은 간단하며 주요 단계는 모듈로를 동일한 인수에 대해 반복적으로 적용하는 것입니다.큰 범위의 n개 값을 루프하는 것이 중요합니다.다음은 c(P는 소수의 표)로 작성된 순진한 접근법입니다.
uint64_t factorial_naive(uint64_t const nmin, uint64_t const nmax, const uint64_t *restrict P)
{
uint64_t n, i, residue;
for (i = 0; i < APP_BUFLEN; i++){
residue = 2;
for (n=3; n <= nmax; n++){
residue *= n;
residue %= P[i];
// Lets check if we found factor
if (nmin <= n){
if( residue == 1){
report_factor(n, -1, P[i]);
}
if(residue == P[i]- 1){
report_factor(n, 1, P[i]);
}
}
}
}
return EXIT_SUCCESS;
}
여기서 아이디어는 동일한 지수 집합에 대해 n의 넓은 범위, 예를 들어 1,000,000 -> 10,000,000을 확인하는 것입니다.그래서 우리는 모듈로를 동일한 지수에 대해 수백만 번 존중할 것입니다.DIV를 사용하는 것은 매우 느리기 때문에 계산 범위에 따라 몇 가지 가능한 방법이 있습니다.여기서 나의 경우 n은 10^7보다 작고 전위 나눗셈 p는 10,000 G(< 10^13)보다 작기 때문에 숫자는 64비트보다 작고 또한 53비트보다 작지만, 최대 잔류물(p-1) 곱하기 n의 곱은 64비트보다 큽니다.그래서 저는 가장 단순한 버전의 몽고메리 방법은 작동하지 않는다고 생각했습니다. 왜냐하면 우리는 64비트보다 큰 숫자에서 모듈로를 추출하기 때문입니다.
저는 FMA가 더블을 사용할 때 106비트까지 정확한 제품을 얻기 위해 사용된 파워 PC의 오래된 코드를 찾았습니다.그래서 저는 이 접근 방식을 AVX 512 어셈블러(Intel Intrusics)로 변환했습니다.여기 FMA 방법의 간단한 버전이 있습니다. 이는 데커(1971)의 작업, 데커 제품 및 TwoProduct의 FMA 버전을 기반으로 합니다. 이는 이에 대한 근거를 찾거나 검색할 때 유용한 단어입니다.또한 이 접근 방식은 이 포럼에서 논의되었습니다(예: 여기).
int64_t factorial_FMA(uint64_t const nmin, uint64_t const nmax, const uint64_t *restrict P)
{
uint64_t n, i;
double prime_double, prime_double_reciprocal, quotient, residue;
double nr, n_double, prime_times_quotient_high, prime_times_quotient_low;
for (i = 0; i < APP_BUFLEN; i++){
residue = 2.0;
prime_double = (double)P[i];
prime_double_reciprocal = 1.0 / prime_double;
n_double = 3.0;
for (n=3; n <= nmax; n++){
nr = n_double * residue;
quotient = fma(nr, prime_double_reciprocal, rounding_constant);
quotient -= rounding_constant;
prime_times_quotient_high= prime_double * quotient;
prime_times_quotient_low = fma(prime_double, quotient, -prime_times_quotient_high);
residue = fma(residue, n, -prime_times_quotient_high) - prime_times_quotient_low;
if (residue < 0.0) residue += prime_double;
n_double += 1.0;
// Lets check if we found factor
if (nmin <= n){
if( residue == 1.0){
report_factor(n, -1, P[i]);
}
if(residue == prime_double - 1.0){
report_factor(n, 1, P[i]);
}
}
}
}
return EXIT_SUCCESS;
}
여기서 나는 마법 상수를 사용했습니다.
static const double rounding_constant = 6755399441055744.0;
그것은 복식에 대한 2^51 + 2^52 마법 번호입니다.
저는 이것을 AVX512(루프당 32개의 잠재적인 지수)로 변환하고 IACA를 사용하여 결과를 분석했습니다.처리량 병목 현상: 백엔드 및 백엔드 할당이 할당 리소스를 사용할 수 없어 지연되었다고 합니다.저는 조립자에 대한 경험이 별로 없습니다. 제 질문은 제가 이 작업을 가속화하고 백엔드 병목 현상을 해결하기 위해 할 수 있는 일이 있습니까?
AVX512 코드는 여기에 있으며 github에서도 찾을 수 있습니다.
uint64_t factorial_AVX512_unrolled_four(uint64_t const nmin, uint64_t const nmax, const uint64_t *restrict P)
{
// we are trying to find a factor for a factorial numbers : n! +-1
//nmin is minimum n we want to report and nmax is maximum. P is table of primes
// we process 32 primes in one loop.
// naive version of the algorithm is int he function factorial_naive
// and simple version of the FMA based approach in the function factorial_simpleFMA
const double one_table[8] __attribute__ ((aligned(64))) ={1.0, 1.0, 1.0,1.0,1.0,1.0,1.0,1.0};
uint64_t n;
__m512d zero, rounding_const, one, n_double;
__m512i prime1, prime2, prime3, prime4;
__m512d residue1, residue2, residue3, residue4;
__m512d prime_double_reciprocal1, prime_double_reciprocal2, prime_double_reciprocal3, prime_double_reciprocal4;
__m512d quotient1, quotient2, quotient3, quotient4;
__m512d prime_times_quotient_high1, prime_times_quotient_high2, prime_times_quotient_high3, prime_times_quotient_high4;
__m512d prime_times_quotient_low1, prime_times_quotient_low2, prime_times_quotient_low3, prime_times_quotient_low4;
__m512d nr1, nr2, nr3, nr4;
__m512d prime_double1, prime_double2, prime_double3, prime_double4;
__m512d prime_minus_one1, prime_minus_one2, prime_minus_one3, prime_minus_one4;
__mmask8 negative_reminder_mask1, negative_reminder_mask2, negative_reminder_mask3, negative_reminder_mask4;
__mmask8 found_factor_mask11, found_factor_mask12, found_factor_mask13, found_factor_mask14;
__mmask8 found_factor_mask21, found_factor_mask22, found_factor_mask23, found_factor_mask24;
// load data and initialize cariables for loop
rounding_const = _mm512_set1_pd(rounding_constant);
one = _mm512_load_pd(one_table);
zero = _mm512_setzero_pd ();
// load primes used to sieve
prime1 = _mm512_load_epi64((__m512i *) &P[0]);
prime2 = _mm512_load_epi64((__m512i *) &P[8]);
prime3 = _mm512_load_epi64((__m512i *) &P[16]);
prime4 = _mm512_load_epi64((__m512i *) &P[24]);
// convert primes to double
prime_double1 = _mm512_cvtepi64_pd (prime1); // vcvtqq2pd
prime_double2 = _mm512_cvtepi64_pd (prime2); // vcvtqq2pd
prime_double3 = _mm512_cvtepi64_pd (prime3); // vcvtqq2pd
prime_double4 = _mm512_cvtepi64_pd (prime4); // vcvtqq2pd
// calculates 1.0/ prime
prime_double_reciprocal1 = _mm512_div_pd(one, prime_double1);
prime_double_reciprocal2 = _mm512_div_pd(one, prime_double2);
prime_double_reciprocal3 = _mm512_div_pd(one, prime_double3);
prime_double_reciprocal4 = _mm512_div_pd(one, prime_double4);
// for comparison if we have found factors for n!+1
prime_minus_one1 = _mm512_sub_pd(prime_double1, one);
prime_minus_one2 = _mm512_sub_pd(prime_double2, one);
prime_minus_one3 = _mm512_sub_pd(prime_double3, one);
prime_minus_one4 = _mm512_sub_pd(prime_double4, one);
// residue init
residue1 = _mm512_set1_pd(2.0);
residue2 = _mm512_set1_pd(2.0);
residue3 = _mm512_set1_pd(2.0);
residue4 = _mm512_set1_pd(2.0);
// double counter init
n_double = _mm512_set1_pd(3.0);
// main loop starts here. typical value for nmax can be 5,000,000 -> 10,000,000
for (n=3; n<=nmax; n++) // main loop
{
// timings for instructions:
// _mm512_load_epi64 = vmovdqa64 : L 1, T 0.5
// _mm512_load_pd = vmovapd : L 1, T 0.5
// _mm512_set1_pd
// _mm512_div_pd = vdivpd : L 23, T 16
// _mm512_cvtepi64_pd = vcvtqq2pd : L 4, T 0,5
// _mm512_mul_pd = vmulpd : L 4, T 0.5
// _mm512_fmadd_pd = vfmadd132pd, vfmadd213pd, vfmadd231pd : L 4, T 0.5
// _mm512_fmsub_pd = vfmsub132pd, vfmsub213pd, vfmsub231pd : L 4, T 0.5
// _mm512_sub_pd = vsubpd : L 4, T 0.5
// _mm512_cmplt_pd_mask = vcmppd : L ?, Y 1
// _mm512_mask_add_pd = vaddpd : L 4, T 0.5
// _mm512_cmpeq_pd_mask = vcmppd L ?, Y 1
// _mm512_kor = korw L 1, T 1
// nr = residue * n
nr1 = _mm512_mul_pd (residue1, n_double);
nr2 = _mm512_mul_pd (residue2, n_double);
nr3 = _mm512_mul_pd (residue3, n_double);
nr4 = _mm512_mul_pd (residue4, n_double);
// quotient = nr * 1.0/ prime_double + rounding_constant
quotient1 = _mm512_fmadd_pd(nr1, prime_double_reciprocal1, rounding_const);
quotient2 = _mm512_fmadd_pd(nr2, prime_double_reciprocal2, rounding_const);
quotient3 = _mm512_fmadd_pd(nr3, prime_double_reciprocal3, rounding_const);
quotient4 = _mm512_fmadd_pd(nr4, prime_double_reciprocal4, rounding_const);
// quotient -= rounding_constant, now quotient is rounded to integer
// countient should be at maximum nmax (10,000,000)
quotient1 = _mm512_sub_pd(quotient1, rounding_const);
quotient2 = _mm512_sub_pd(quotient2, rounding_const);
quotient3 = _mm512_sub_pd(quotient3, rounding_const);
quotient4 = _mm512_sub_pd(quotient4, rounding_const);
// now we calculate high and low for prime * quotient using decker product (FMA).
// quotient is calculated using approximation but this is accurate for given quotient
prime_times_quotient_high1 = _mm512_mul_pd(quotient1, prime_double1);
prime_times_quotient_high2 = _mm512_mul_pd(quotient2, prime_double2);
prime_times_quotient_high3 = _mm512_mul_pd(quotient3, prime_double3);
prime_times_quotient_high4 = _mm512_mul_pd(quotient4, prime_double4);
prime_times_quotient_low1 = _mm512_fmsub_pd(quotient1, prime_double1, prime_times_quotient_high1);
prime_times_quotient_low2 = _mm512_fmsub_pd(quotient2, prime_double2, prime_times_quotient_high2);
prime_times_quotient_low3 = _mm512_fmsub_pd(quotient3, prime_double3, prime_times_quotient_high3);
prime_times_quotient_low4 = _mm512_fmsub_pd(quotient4, prime_double4, prime_times_quotient_high4);
// now we calculate new reminder using decker product and using original values
// we subtract above calculated prime * quotient (quotient is aproximation)
residue1 = _mm512_fmsub_pd(residue1, n_double, prime_times_quotient_high1);
residue2 = _mm512_fmsub_pd(residue2, n_double, prime_times_quotient_high2);
residue3 = _mm512_fmsub_pd(residue3, n_double, prime_times_quotient_high3);
residue4 = _mm512_fmsub_pd(residue4, n_double, prime_times_quotient_high4);
residue1 = _mm512_sub_pd(residue1, prime_times_quotient_low1);
residue2 = _mm512_sub_pd(residue2, prime_times_quotient_low2);
residue3 = _mm512_sub_pd(residue3, prime_times_quotient_low3);
residue4 = _mm512_sub_pd(residue4, prime_times_quotient_low4);
// lets check if reminder < 0
negative_reminder_mask1 = _mm512_cmplt_pd_mask(residue1,zero);
negative_reminder_mask2 = _mm512_cmplt_pd_mask(residue2,zero);
negative_reminder_mask3 = _mm512_cmplt_pd_mask(residue3,zero);
negative_reminder_mask4 = _mm512_cmplt_pd_mask(residue4,zero);
// we and prime back to reminder using mask if it was < 0
residue1 = _mm512_mask_add_pd(residue1, negative_reminder_mask1, residue1, prime_double1);
residue2 = _mm512_mask_add_pd(residue2, negative_reminder_mask2, residue2, prime_double2);
residue3 = _mm512_mask_add_pd(residue3, negative_reminder_mask3, residue3, prime_double3);
residue4 = _mm512_mask_add_pd(residue4, negative_reminder_mask4, residue4, prime_double4);
n_double = _mm512_add_pd(n_double,one);
// if we are below nmin then we continue next iteration
if (n < nmin) continue;
// Lets check if we found any factors, residue 1 == n!-1
found_factor_mask11 = _mm512_cmpeq_pd_mask(one, residue1);
found_factor_mask12 = _mm512_cmpeq_pd_mask(one, residue2);
found_factor_mask13 = _mm512_cmpeq_pd_mask(one, residue3);
found_factor_mask14 = _mm512_cmpeq_pd_mask(one, residue4);
// residue prime -1 == n!+1
found_factor_mask21 = _mm512_cmpeq_pd_mask(prime_minus_one1, residue1);
found_factor_mask22 = _mm512_cmpeq_pd_mask(prime_minus_one2, residue2);
found_factor_mask23 = _mm512_cmpeq_pd_mask(prime_minus_one3, residue3);
found_factor_mask24 = _mm512_cmpeq_pd_mask(prime_minus_one4, residue4);
if (found_factor_mask12 | found_factor_mask11 | found_factor_mask13 | found_factor_mask14 |
found_factor_mask21 | found_factor_mask22 | found_factor_mask23|found_factor_mask24)
{ // we find factor very rarely
double *residual_list1 = (double *) &residue1;
double *residual_list2 = (double *) &residue2;
double *residual_list3 = (double *) &residue3;
double *residual_list4 = (double *) &residue4;
double *prime_list1 = (double *) &prime_double1;
double *prime_list2 = (double *) &prime_double2;
double *prime_list3 = (double *) &prime_double3;
double *prime_list4 = (double *) &prime_double4;
for (int i=0; i <8; i++){
if( residual_list1[i] == 1.0)
{
report_factor((uint64_t) n, -1, (uint64_t) prime_list1[i]);
}
if( residual_list2[i] == 1.0)
{
report_factor((uint64_t) n, -1, (uint64_t) prime_list2[i]);
}
if( residual_list3[i] == 1.0)
{
report_factor((uint64_t) n, -1, (uint64_t) prime_list3[i]);
}
if( residual_list4[i] == 1.0)
{
report_factor((uint64_t) n, -1, (uint64_t) prime_list4[i]);
}
if(residual_list1[i] == (prime_list1[i] - 1.0))
{
report_factor((uint64_t) n, 1, (uint64_t) prime_list1[i]);
}
if(residual_list2[i] == (prime_list2[i] - 1.0))
{
report_factor((uint64_t) n, 1, (uint64_t) prime_list2[i]);
}
if(residual_list3[i] == (prime_list3[i] - 1.0))
{
report_factor((uint64_t) n, 1, (uint64_t) prime_list3[i]);
}
if(residual_list4[i] == (prime_list4[i] - 1.0))
{
report_factor((uint64_t) n, 1, (uint64_t) prime_list4[i]);
}
}
}
}
return EXIT_SUCCESS;
}
몇 명의 논평가가 제안했듯이, "백엔드" 병목 현상은 이 코드에 대해 기대할 수 있는 것입니다.그것은 당신이 원하는 것을 꽤 잘 먹이고 있다는 것을 암시합니다.
보고서를 보면 이 섹션에 기회가 있을 것입니다.
// Lets check if we found any factors, residue 1 == n!-1
found_factor_mask11 = _mm512_cmpeq_pd_mask(one, residue1);
found_factor_mask12 = _mm512_cmpeq_pd_mask(one, residue2);
found_factor_mask13 = _mm512_cmpeq_pd_mask(one, residue3);
found_factor_mask14 = _mm512_cmpeq_pd_mask(one, residue4);
// residue prime -1 == n!+1
found_factor_mask21 = _mm512_cmpeq_pd_mask(prime_minus_one1, residue1);
found_factor_mask22 = _mm512_cmpeq_pd_mask(prime_minus_one2, residue2);
found_factor_mask23 = _mm512_cmpeq_pd_mask(prime_minus_one3, residue3);
found_factor_mask24 = _mm512_cmpeq_pd_mask(prime_minus_one4, residue4);
if (found_factor_mask12 | found_factor_mask11 | found_factor_mask13 | found_factor_mask14 |
found_factor_mask21 | found_factor_mask22 | found_factor_mask23|found_factor_mask24)
IACA 분석 결과:
| 1 | 1.0 | | | | | | | | kmovw r11d, k0
| 1 | 1.0 | | | | | | | | kmovw eax, k1
| 1 | 1.0 | | | | | | | | kmovw ecx, k2
| 1 | 1.0 | | | | | | | | kmovw esi, k3
| 1 | 1.0 | | | | | | | | kmovw edi, k4
| 1 | 1.0 | | | | | | | | kmovw r8d, k5
| 1 | 1.0 | | | | | | | | kmovw r9d, k6
| 1 | 1.0 | | | | | | | | kmovw r10d, k7
| 1 | | 1.0 | | | | | | | or r11d, eax
| 1 | | | | | | | 1.0 | | or r11d, ecx
| 1 | | 1.0 | | | | | | | or r11d, esi
| 1 | | | | | | | 1.0 | | or r11d, edi
| 1 | | 1.0 | | | | | | | or r11d, r8d
| 1 | | | | | | | 1.0 | | or r11d, r9d
| 1* | | | | | | | | | or r11d, r10d
프로세서가 "또는" 작업을 위해 결과 비교 마스크(k0-k7)를 일반 레지스터로 이동합니다.이러한 움직임을 제거하고 "또는" 롤업을 8개가 아닌 6개의 작업으로 수행할 수 있어야 합니다.
은 "found_factor_mask"로 됩니다.__mmask8
이 할 __mask16
(512비트 이펙터에서 16x 이중 부동).그러면 컴파일러가 몇 가지 최적화를 수행할 수 있습니다.그렇지 않은 경우 주석자가 언급한 대로 어셈블리로 이동합니다.
관련: 이 or-mask 조항을 발사하는 상호작용의 비율은 얼마입니까?다른 의견제출자가 관찰한 바와 같이 누적 "또는" 작업으로 이를 롤아웃할 수 있습니다.각 롤백된 반복이 끝날 때(또는 N번 반복 후) 누적된 "or" 값을 확인하고, "true"이면 다시 돌아가서 값을 다시 수행하여 n번 값이 트리거되었는지 확인합니다.
(그리고 "roll" 내에서 이진 검색을 통해 일치하는 n 값을 찾을 수 있습니다. 이 값은 약간의 이득을 얻을 수 있습니다.)
다음으로, 이 중간 루프 검사를 제거할 수 있습니다.
// if we are below nmin then we continue next iteration, we
if (n < nmin) continue;
다음과 같이 표시됩니다.
| 1* | | | | | | | | | cmp r14, 0x3e8
| 0*F | | | | | | | | | jb 0x229
예측 변수가 이 값을 (아마도) 맞출 것이기 때문에 큰 이득은 아닐 수 있지만, 두 "단계"에 대해 두 개의 뚜렷한 루프를 가짐으로써 어느 정도 이득을 얻을 수 있습니다.
- n=3 ~ n=nmin-1
- n=nmin 이상
사이클이 증가한다고 해도 3%입니다.그리고 그것은 일반적으로 큰 '또는' 작업과 관련이 있기 때문에, 위에서, 그 안에 더 많은 영리함이 발견될 수 있습니다.
언급URL : https://stackoverflow.com/questions/47855357/fast-avx512-modulo-when-same-divisor
'programing' 카테고리의 다른 글
이미지를 사용하는 것보다 서클 디브를 만드는 더 쉬운 방법은? (0) | 2023.09.04 |
---|---|
PowerShell 버전 정렬 (0) | 2023.09.04 |
PowerShell에서 해시 테이블 병합: 어떻게? (0) | 2023.09.04 |
판다 데이터 프레임에서 사용하는 메모리를 해제하려면 어떻게 해야 합니까? (0) | 2023.09.04 |
AJAX를 통해 제출하려고 할 때 %5Bobject%20Object%5D(404를 찾을 수 없음) (0) | 2023.09.04 |