使用C++(基于libtommath大数库)实现了RSA加解密算法,可以加解密文件和字符串。
rsa算法原理
- 选择两个大素数p和q;
- 计算n = p * q;
- 计算φ(n)=(p-1)(q-1);
- 选择与φ(n)互素的整数d;
- 由de=1 mod φ(n)计算得到e;
- 公钥是(e, n), 私钥是(d, n);
- 假设明文是M(一个满足 $0 \le M < n$ 的整数),则密文 $C = M^{e} \bmod n$,此为加密过程;
- 解密过程为 $M = C^{d} \bmod n$;
rsa.h
#pragma once
#include <climits>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <time.h>
#include <fstream>
#include <iostream>
#include <random>
#include <string>
#include <tommath.h>
#include <Windows.h>
// Decimal digits in each random prime p and q. Their product n = p * q has
// 155-156 decimal digits, i.e. roughly 511-518 bits.
#define SUBKEY_LENGTH 78
// Capacity (including the terminating NUL) of the key-file path buffers.
#define FILE_NAME_LENGTH 99
// Plaintext block size in bytes; every block encrypts to PLAINTEXT_LENGTH + 1 bytes.
#define PLAINTEXT_LENGTH 64
// Plaintext block size in bits (PLAINTEXT_LENGTH * 8).
#define BINARY_LENGTH 512
using namespace std;

// ---- setup ---------------------------------------------------------------
// Initializes the shared big-number constants; call once before anything else.
void initial();

// ---- bit / byte conversion -----------------------------------------------
// Bit `pos` of c, counted from the most significant bit (0) to the least (7).
int Get_char_bit(char c, int pos);
// Writes 8 * len '0'/'1' characters for text[0..len) plus a NUL into binary.
void Char_2_binary(char *text, char *binary, int len);
// Packs every 8 '0'/'1' characters of binary into one byte of text (no NUL added).
void Binary_2_char(char *binary, char *text, int binary_len);
// Left-pads a '0'/'1' string with '0': mode 0 to BINARY_LENGTH + 8, mode 1 to BINARY_LENGTH.
void ADD_0(char *binary, int mode);

// ---- random numbers and primes -------------------------------------------
// mode 0: random prime candidate; mode 1: random Miller-Rabin base.
void Create_number(mp_int *number, int mode);
// One Miller-Rabin round with a random base; non-zero means "probably prime".
int Miller_rabin(mp_int *number);
// Stores a probable prime of about SUBKEY_LENGTH decimal digits in number.
void Create_prime_number(mp_int *number);

// ---- keys ----------------------------------------------------------------
// Writes d:\public_key_<key_name> and d:\private_key_<key_name>.
void Generate_key(char *key_name);
// Writes n on the first line and key on the second line of file_name.
void Write_key_2_File(char *file_name, mp_int *key, mp_int *n);
// Reads (n, key) from the key file at path key_name; mode 0 public, mode 1 private.
void Get_key_from_file(mp_int *key, mp_int *n, char *key_name, int mode);

// ---- encryption / decryption ---------------------------------------------
// mode 0: src/dst are in-memory strings; mode 1: src/dst are file paths.
void rsa_encrypt(char *src, char *dst, char *key_name, int mode);
void rsa_decrypt(char *src, char *dst, char *key_name, int mode);

// ---- utilities -----------------------------------------------------------
// Prints number in decimal followed by a newline.
void mp_print(mp_int *number);
// d = a ^ b (mod c)
void quick_pow(mp_int *a, mp_int *b, mp_int *c, mp_int *d);
// Size of the file in bytes, or -1 if it cannot be opened.
int Get_file_length(char *filename);
rsa.cpp
#include "rsa.h"
// Small constants shared by the big-number routines in this file.
// They are set up by initial().
mp_int zero;
mp_int one;
mp_int two;
mp_int five;
// Returns bit `pos` of `c` as 0 or 1, where pos 0 is the most significant bit
// and pos 7 the least significant one.
int Get_char_bit(char c, int pos)
{
    const int shifted = c >> (7 - pos);   // c is promoted to int before the shift
    return shifted & 1;
}
// Fills *number with a random decimal value.
//   mode 0: an SUBKEY_LENGTH-digit prime candidate: first digit 1-9, last digit
//           one of 1, 3, 7, 9 (odd and not a multiple of 5).
//   mode 1: a Miller-Rabin base with 2..SUBKEY_LENGTH-1 digits and a non-zero
//           leading digit, so 10 <= number < 10^(SUBKEY_LENGTH-1), which keeps it
//           below any mode-0 candidate.
//   any other mode: *number is left unchanged.
//
// Randomness comes from one process-wide std::random_device. The previous
// implementation re-seeded rand() with time(NULL) on every call, so every call in
// the same second returned the same value: all Miller-Rabin rounds used the same
// base, and the primes were predictable from the clock.
// NOTE(review): std::random_device is OS-entropy backed on MSVC and libstdc++/Linux;
// confirm it is non-deterministic on the toolchain you ship with.
void Create_number(mp_int *number, int mode)
{
    static std::random_device entropy;
    std::uniform_int_distribution<int> lead_digit(1, 9);
    std::uniform_int_distribution<int> any_digit(0, 9);

    int length;
    if (0 == mode)
        length = SUBKEY_LENGTH;
    else if (1 == mode)
        length = std::uniform_int_distribution<int>(2, SUBKEY_LENGTH - 1)(entropy);
    else
        return;

    char temp_number[SUBKEY_LENGTH + 1];
    temp_number[0] = static_cast<char>('0' + lead_digit(entropy));
    for (int i = 1; i < length; ++i)
        temp_number[i] = static_cast<char>('0' + any_digit(entropy));
    if (0 == mode)
    {
        // Odd and not divisible by 5, so trivially composite candidates are skipped.
        static const char odd_tail[] = "1379";
        temp_number[length - 1] = odd_tail[std::uniform_int_distribution<int>(0, 3)(entropy)];
    }
    temp_number[length] = '\0';
    mp_read_radix(number, temp_number, 10);
}
// Runs one Miller-Rabin round on *number with a freshly drawn random base
// (Create_number mode 1). Returns the libtommath verdict: non-zero means
// "probably prime", zero means "definitely composite".
int Miller_rabin(mp_int *number)
{
    mp_int witness;
    int passed;

    mp_init_size(&witness, SUBKEY_LENGTH);
    Create_number(&witness, 1);
    mp_prime_miller_rabin(number, &witness, &passed);
    mp_clear(&witness);

    return passed;
}
void Create_prime_number(mp_int *number)
{
mp_int r;
mp_init(&r);
int time = 100;
int result;
int i;
Create_number(number, 0);
while (1)
{
mp_prime_is_divisible(number, &result);
if (0 != result)
{
do
{
mp_add(number, &two, number);
mp_mod(number, &five, &r);
} while (MP_EQ == mp_cmp_mag(&zero, &r));
continue;
}
for (i = 0; i < time; i++)
{
if (!Miller_rabin(number))
break;
}
if (i == time)
return;
else
{
do
{
mp_add(number, &two, number);
mp_mod(number, &five, &r);
} while (MP_EQ == mp_cmp_mag(&zero, &r));
}
}
}
void Generate_key(char *key_name)
{
mp_int p, q, n, f, e, d;
mp_init_size(&p, SUBKEY_LENGTH);
mp_init_size(&q, SUBKEY_LENGTH);
mp_init_size(&n, SUBKEY_LENGTH * 2);
mp_init_size(&f, SUBKEY_LENGTH * 2);
mp_init_size(&d, SUBKEY_LENGTH * 2);
mp_init_size(&e, SUBKEY_LENGTH * 2);
Create_prime_number(&p);
Sleep(1000);
Create_prime_number(&q);
mp_mul(&p, &q, &n);
mp_sub(&p, &one, &p);
mp_sub(&q, &one, &q);
mp_mul(&p, &q, &f);
mp_set_int(&d, 127);
mp_int gcd;
mp_init_size(&gcd, SUBKEY_LENGTH * 2);
do
{
mp_add(&d, &two, &d);
mp_gcd(&d, &f, &gcd);
} while (MP_EQ != mp_cmp_mag(&gcd, &one));
mp_invmod(&d, &f, &e);
char PUBLIC_KEY[FILE_NAME_LENGTH] = "d:\\public_key_";
char PRIVATE_KEY[FILE_NAME_LENGTH] = "d:\\private_key_";
strcat(PUBLIC_KEY, key_name);
strcat(PRIVATE_KEY, key_name);
Write_key_2_File(PUBLIC_KEY, &d, &n);
Write_key_2_File(PRIVATE_KEY, &e, &n);
mp_clear_multi(&p, &q, &e, &d, &f, &n, &gcd, NULL);
}
void initial()
{
mp_init_set_int(&two, 2);
mp_init_set_int(&five, 5);
mp_init_set_int(&zero, 0);
mp_init_set_int(&one, 1);
}
void Write_key_2_File(char *file_name, mp_int *key, mp_int *n)
{
remove(file_name);
FILE *fp = fopen(file_name, "w+");
if (NULL == fp)
{
cout << "open file error!" << endl;
return;
}
char key_str[SUBKEY_LENGTH * 2 + 1];
char n_str[SUBKEY_LENGTH * 2 + 1];
mp_toradix(key, key_str, 10);
mp_toradix(n, n_str, 10);
fprintf(fp, "%s\n", n_str);
fprintf(fp, "%s", key_str);
fclose(fp);
}
// Prints *number in decimal followed by a newline.
void mp_print(mp_int *number)
{
    // Size the buffer from the value: a b-bit number has at most b/3 + 1 decimal
    // digits, +2 for a sign and the NUL. A fixed SUBKEY_LENGTH * 2 + 1 buffer
    // overflowed for larger values.
    std::string str(static_cast<size_t>(mp_count_bits(number) / 3 + 3), '\0');
    mp_toradix(number, &str[0], 10);
    cout << str.c_str() << endl;
}
// mode 0: get public key
// mode 1: get private key
//
// Reads a key file written by Write_key_2_File: the decimal modulus n, then the
// decimal exponent, separated by whitespace. key_name is the full path of the
// file, in both modes. On failure an error is printed and *key / *n are left
// unchanged.
//
// Fixes over the previous version:
//   - the path is no longer copied into a 99-byte buffer with strcat (overflow);
//   - the file is opened read-only ("r+" failed on read-only key files);
//   - numbers of any length are read, with no fixed-size fscanf("%s") buffers;
//   - a truncated file is detected instead of parsing uninitialized memory.
void Get_key_from_file(mp_int *key, mp_int *n, char *key_name, int mode)
{
    if (NULL == key_name || (0 != mode && 1 != mode))
    {
        cout << "open file error!" << endl;
        return;
    }
    std::ifstream in(key_name);
    if (!in.is_open())
    {
        cout << "open file error!" << endl;
        return;
    }
    std::string n_str;
    std::string key_str;
    if (!(in >> n_str >> key_str))
    {
        cout << "read key file error!" << endl;
        return;
    }
    // &str[0] yields a mutable, NUL-terminated buffer for older libtommath
    // signatures that take char*.
    mp_read_radix(n, &n_str[0], 10);
    mp_read_radix(key, &key_str[0], 10);
}
// Expands text[0..len) into a string of '0'/'1' characters, 8 per byte, most
// significant bit first, and NUL-terminates it. binary must hold 8 * len + 1 chars.
void Char_2_binary(char *text, char *binary, int len)
{
    char *out = binary;
    for (int idx = 0; idx < len; ++idx)
    {
        for (int shift = 7; shift >= 0; --shift)
            *out++ = static_cast<char>('0' + ((text[idx] >> shift) & 1));
    }
    *out = '\0';
}
// Packs the '0'/'1' characters binary[0..binary_len) into bytes, most significant
// bit first, writing binary_len / 8 bytes to text. Trailing bits that do not make
// up a whole byte are ignored, and no NUL terminator is written.
//
// Uses an integer shift instead of floating-point pow(2, 7 - j) with an implicit
// double -> int conversion.
void Binary_2_char(char *binary, char *text, int binary_len)
{
    int byte = 0;
    int k = 0;
    for (int i = 0; i < binary_len; ++i)
    {
        byte = (byte << 1) | (binary[i] == '1' ? 1 : 0);
        if (7 == i % 8)
        {
            text[k++] = static_cast<char>(byte);
            byte = 0;
        }
    }
}
// mode 0: encrypt char*
// mode 1: encrypt file
void rsa_encrypt(char *src, char *dst, char *key_name, int mode)
{
mp_int n;
mp_int public_key;
mp_int plain_number;
mp_int cipher_number;
mp_init_size(&n, SUBKEY_LENGTH * 2);
mp_init_size(&public_key, SUBKEY_LENGTH * 2);
mp_init_size(&plain_number, SUBKEY_LENGTH * 2);
mp_init_size(&cipher_number, SUBKEY_LENGTH * 2);
Get_key_from_file(&public_key, &n, key_name, 0);
char plain_text[PLAINTEXT_LENGTH + 1];
char cipher_text[PLAINTEXT_LENGTH + 2];
char plain_binary[BINARY_LENGTH + 1];
char cipher_binary[BINARY_LENGTH + 9];
int i;
int j;
int k;
int l;
if (0 == mode)
{
int round = ceil(strlen(src) / 64.0);
for (i = 0; i < round; i++)
{
for (l = 0; l <= BINARY_LENGTH; l++)
plain_binary[l] = '\0';
for (l = 0; l <= BINARY_LENGTH + 8; l++)
cipher_binary[l] = '\0';
for (l = 0; l <= PLAINTEXT_LENGTH; l++)
plain_text[l] = '\0';
for (l = 0; l <= PLAINTEXT_LENGTH + 1; l++)
cipher_text[l] = '\0';
for (j = 0; j <= PLAINTEXT_LENGTH - 1; j++)
{
plain_text[j] = src[i * PLAINTEXT_LENGTH + j];
if ('\0' == plain_text[j])
{
for (k = j + 1; k <= PLAINTEXT_LENGTH - 1; k++)
plain_text[k] = '\0';
break;
}
}
plain_text[PLAINTEXT_LENGTH] = '\0';
Char_2_binary(plain_text, plain_binary, PLAINTEXT_LENGTH);
mp_zero(&plain_number);
mp_read_radix(&plain_number, plain_binary, 2);
mp_zero(&cipher_number);
// ----------------------------------------------------------------
mp_exptmod(&plain_number, &public_key, &n, &cipher_number);
//quick_pow(&plain_number, &public_key, &n, &cipher_number);
// ----------------------------------------------------------------
mp_toradix(&cipher_number, cipher_binary, 2);
ADD_0(cipher_binary, 0);
Binary_2_char(cipher_binary, cipher_text, BINARY_LENGTH + 8);
for (l = 0; l <= PLAINTEXT_LENGTH; l++)
dst[i*(PLAINTEXT_LENGTH + 1) + l] = cipher_text[l];
}
dst[round * (PLAINTEXT_LENGTH + 1)] = '\0';
mp_clear_multi(&n, &public_key, &plain_number, &cipher_number, NULL);
}
else if (1 == mode)
{
ifstream fin;
ofstream fout;
fin.open(src, ios::binary);
fout.open(dst, ios::binary);
char ch;
int i = 0;
int len = Get_file_length(src) % 64;
fout << len << endl;
while (1)
{
fin.get(ch);
plain_text[i++] = ch;
if (i == PLAINTEXT_LENGTH || fin.eof())
{
if (i == PLAINTEXT_LENGTH)
plain_text[i] = '\0';
else
{
for (k = i - 1; k <= PLAINTEXT_LENGTH; k++)
plain_text[k] = '\0';
}
for (l = 0; l <= BINARY_LENGTH; l++)
plain_binary[l] = '\0';
for (l = 0; l <= BINARY_LENGTH + 8; l++)
cipher_binary[l] = '\0';
for (l = 0; l <= PLAINTEXT_LENGTH + 1; l++)
cipher_text[l] = '\0';
Char_2_binary(plain_text, plain_binary, PLAINTEXT_LENGTH);
mp_zero(&plain_number);
mp_read_radix(&plain_number, plain_binary, 2);
mp_zero(&cipher_number);
// ----------------------------------------------------------------
mp_exptmod(&plain_number, &public_key, &n, &cipher_number);
//quick_pow(&plain_number, &public_key, &n, &cipher_number);
// ----------------------------------------------------------------
mp_toradix(&cipher_number, cipher_binary, 2);
ADD_0(cipher_binary, 0);
Binary_2_char(cipher_binary, cipher_text, BINARY_LENGTH + 8);
for (l = 0; l <= PLAINTEXT_LENGTH; l++)
fout << cipher_text[l];
for (l = 0; l <= PLAINTEXT_LENGTH; l++)
plain_text[l] = '\0';
i = 0;
}
if (fin.eof())
break;
}
fin.close();
fout.close();
mp_clear_multi(&n, &public_key, &plain_number, &cipher_number, NULL);
}
}
// mode 0: for encrypt -- pad to BINARY_LENGTH + 8 characters
// mode 1: for decrypt -- pad to BINARY_LENGTH characters
//
// Left-pads the NUL-terminated '0'/'1' string `binary` in place with '0' characters
// up to the target width for `mode`. The buffer must hold target width + 1 chars.
// A string that is already at least as wide, or an unknown mode, is left unchanged.
// The old code used an uninitialized offset or a negative one in those cases.
void ADD_0(char *binary, int mode)
{
    int target;
    if (0 == mode)
        target = BINARY_LENGTH + 8;
    else if (1 == mode)
        target = BINARY_LENGTH;
    else
        return;

    const int length = static_cast<int>(strlen(binary));
    const int difference = target - length;
    if (difference <= 0)
        return;
    memmove(binary + difference, binary, static_cast<size_t>(length) + 1);  // include NUL
    memset(binary, '0', static_cast<size_t>(difference));
}
// mode 0: decrypt char*
// mode 1: decrypt file
void rsa_decrypt(char *src, char *dst, char *key_name, int mode)
{
mp_int n;
mp_int private_key;
mp_int plain_number;
mp_int cipher_number;
mp_init_size(&n, SUBKEY_LENGTH * 2);
mp_init_size(&private_key, SUBKEY_LENGTH * 2);
mp_init_size(&plain_number, SUBKEY_LENGTH * 2);
mp_init_size(&cipher_number, SUBKEY_LENGTH * 2);
Get_key_from_file(&private_key, &n, key_name, 1);
char plain_text[PLAINTEXT_LENGTH + 1];
char cipher_text[PLAINTEXT_LENGTH + 2];
char plain_binary[BINARY_LENGTH + 1];
char cipher_binary[BINARY_LENGTH + 9];
int j;
int k;
int l;
if (0 == mode)
{
k = -1;
do
{
k++;
for (l = 0; l <= PLAINTEXT_LENGTH; l++)
plain_text[l] = '\0';
for (l = 0; l <= BINARY_LENGTH; l++)
plain_binary[l] = '\0';
for (l = 0; l <= BINARY_LENGTH + 8; l++)
cipher_binary[l] = '\0';
for (j = 0; j <= PLAINTEXT_LENGTH; j++)
cipher_text[j] = src[k * (PLAINTEXT_LENGTH + 1) + j];
cipher_text[PLAINTEXT_LENGTH + 1] = '\0';
Char_2_binary(cipher_text, cipher_binary, PLAINTEXT_LENGTH + 1);
mp_zero(&cipher_number);
mp_read_radix(&cipher_number, cipher_binary, 2);
mp_zero(&plain_number);
// ----------------------------------------------------------------
mp_exptmod(&cipher_number, &private_key, &n, &plain_number);
//quick_pow(&cipher_number, &private_key, &n, &plain_number);
// ----------------------------------------------------------------
mp_toradix(&plain_number, plain_binary, 2);
ADD_0(plain_binary, 1);
Binary_2_char(plain_binary, plain_text, BINARY_LENGTH);
for (l = 0; l <= PLAINTEXT_LENGTH - 1; l++)
dst[k*PLAINTEXT_LENGTH + l] = plain_text[l];
} while ('\0' != src[(k + 1)*(PLAINTEXT_LENGTH + 1)]);
dst[(k + 1) * PLAINTEXT_LENGTH] = '\0';
}
else if (1 == mode)
{
ifstream fin;
ofstream fout;
fin.open(src, ios::binary);
fout.open(dst, ios::binary);
int i = 0;
int limit = PLAINTEXT_LENGTH;
char ch;
string temp = "";
while (1)
{
fin.get(ch);
if ('\n' == ch)
break;
temp += ch;
}
int len = atoi(temp.c_str());
while (1)
{
fin.get(ch);
cipher_text[i++] = ch;
if (i == PLAINTEXT_LENGTH + 1)
{
if (i == PLAINTEXT_LENGTH + 1)
cipher_text[i] = '\0';
else
cipher_text[i - 1] = '\0';
for (l = 0; l <= BINARY_LENGTH + 8; l++)
cipher_binary[l] = '\0';
for (l = 0; l <= PLAINTEXT_LENGTH; l++)
plain_text[l] = '\0';
for (l = 0; l <= BINARY_LENGTH; l++)
plain_binary[l] = '\0';
Char_2_binary(cipher_text, cipher_binary, PLAINTEXT_LENGTH + 1);
mp_zero(&cipher_number);
mp_read_radix(&cipher_number, cipher_binary, 2);
mp_zero(&plain_number);
// ----------------------------------------------------------------
mp_exptmod(&cipher_number, &private_key, &n, &plain_number);
//quick_pow(&cipher_number, &private_key, &n, &plain_number);
// ----------------------------------------------------------------
mp_toradix(&plain_number, plain_binary, 2);
ADD_0(plain_binary, 1);
Binary_2_char(plain_binary, plain_text, BINARY_LENGTH);
if (fin.peek() == EOF)
limit = len;
for (l = 0; l < limit; l++)
fout << plain_text[l];
for (l = 0; l <= PLAINTEXT_LENGTH + 1; l++)
cipher_text[l] = '\0';
i = 0;
}
if (fin.eof())
break;
}
fin.close();
fout.close();
}
mp_clear_multi(&n, &private_key, &plain_number, &cipher_number, NULL);
}
// compute d = a ^ b (mod c)
//
// Right-to-left square-and-multiply, an alternative to mp_exptmod. A non-positive
// exponent yields 1 mod c (so x^0 mod 1 == 0). The result is built in a local and
// copied to *d at the end, so d may alias a, b or c. The function does not depend
// on the globals set up by initial().
void quick_pow(mp_int *a, mp_int *b, mp_int *c, mp_int *d)
{
    mp_int base;
    mp_int exponent;
    mp_int result;
    mp_init_size(&base, SUBKEY_LENGTH * 2);
    mp_init_size(&exponent, SUBKEY_LENGTH * 2);
    mp_init_size(&result, SUBKEY_LENGTH * 2);

    mp_mod(a, c, &base);
    mp_copy(b, &exponent);
    mp_set_int(&result, 1);
    mp_mod(&result, c, &result);

    while (MP_GT == mp_cmp_d(&exponent, 0))
    {
        if (mp_isodd(&exponent))
        {
            mp_mul(&result, &base, &result);
            mp_mod(&result, c, &result);
        }
        mp_div_2(&exponent, &exponent);
        mp_mul(&base, &base, &base);
        mp_mod(&base, c, &base);
    }

    mp_copy(&result, d);
    mp_clear_multi(&base, &exponent, &result, NULL);
}
// Returns the size of `filename` in bytes. Returns -1 if the file cannot be opened,
// if seeking or ftell fails, or if the size does not fit in an int.
int Get_file_length(char *filename)
{
    FILE *fp = fopen(filename, "rb");
    if (NULL == fp)
        return -1;
    long size = -1;
    if (0 == fseek(fp, 0, SEEK_END))
        size = ftell(fp);   // -1L on failure
    fclose(fp);
    if (size < 0 || size > INT_MAX)
        return -1;
    return static_cast<int>(size);
}
main.h
#pragma once
#include "rsa.h"
main.cpp
#include "main.h"
// example
// Demonstrates key generation, then a string round trip and a file round trip.
// Keys are written to d:\public_key_hello_world and d:\private_key_hello_world, and
// rsa_encrypt/rsa_decrypt are given those full paths. Paths are passed as mutable
// char arrays because the API takes char*; string literals do not convert to char*
// in standard C++.
int main()
{
    initial();

    char key_name[] = "hello_world";
    char public_key_file[] = "d:\\public_key_hello_world";
    char private_key_file[] = "d:\\private_key_hello_world";
    Generate_key(key_name);

    // ----------------------example: encrypt char* ---------------------------
    // Buffers are zero-initialized, so c is printable even if a step fails.
    // c matches b's size so any string of up to 199 chars fits after decryption.
    char a[200] = "to_be_or_not_to_be_it_is_a_question";
    char b[400] = "";
    char c[400] = "";
    rsa_encrypt(a, b, public_key_file, 0);
    rsa_decrypt(b, c, private_key_file, 0);
    cout << "c: " << c << endl;
    // ------------------------------------------------------------------------

    // ----------------------example: encrypt file ----------------------------
    char plain_file[] = "d:\\a.gif";
    char cipher_file[] = "d:\\m";
    char restored_file[] = "D:\\b.gif";
    rsa_encrypt(plain_file, cipher_file, public_key_file, 1);
    rsa_decrypt(cipher_file, restored_file, private_key_file, 1);
    // ------------------------------------------------------------------------
    return 0;
}