/*
CrossCert.c
This file contains code developed by Gemini Security Solutions, Inc.
It is provided without warranty, nor any agreement to provide technical 
support.  Use this code at your own risk!

You are free to incorporate this code into your own applications as you see fit.
If you discover any bugs, please contact wturnes@geminisecurity.com.

Contact Information:
http://www.geminisecurity.com/
wturnes@geminisecurity.com
*/
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Size, in bytes, of the buffer that receives a generated ".xcert" file name,
   including the terminating NUL -- so generated names are capped at 63
   characters.  Adjust this to fit the constraints of your filesystem. */
#define OUTPUT_FNAME_LEN (64)

/*

Function Prototypes

*/

/* read_file
Reads the given file in binary mode into the designated buffer.  The buffer is
allocated inside this function using malloc, so the calling function must call 
free on *buf if the call succeeds.  The length of the buffer is written to 
*length.  fname, buf and length must not be NULL; if any of them is, an error
message is printed and the function returns 0 without reading anything.

This function returns 0 on failure, and non-zero on success.  An empty file is
treated as a failure.  Error messages are written to standard out.

An example for how to use this function follows:

unsigned char *file_buffer = NULL;
unsigned int   file_length = 0;
if( read_file("my.file", &file_buffer, &file_length) ) {
	// do something with the file contents
	free( file_buffer );
	file_buffer = NULL;
	file_length = 0;
}
*/
int read_file( const char* fname, unsigned char **buf, unsigned int *length );

/* get_length_info
Retrieves the DER length-octets information at a given position in the buffer.
'buf' holds 'buf_len' bytes; callers pass in 'loc' the index just past a tag
octet, i.e. the first length octet.  The number of octets used by the length
encoding is returned in the *out_len_len parameter.  The length value specified
by the encoding is returned in the *out_len parameter.

For example, if the 'loc' position in the buffer points to 0x82, 0x10, 0x22,
then 0x03 is returned in *out_len_len, and 0x1022 is returned in *out_len.

Returns 0 if there is an error reading the length information (for example,
'loc' is outside the buffer or the length octets would run past its end), and
1 if successful.  Error messages are written to standard out.  The decoded
length value itself is NOT checked against the buffer size; callers must do
that before using it.
*/
int get_length_info(	const unsigned char* buf, 
						unsigned int loc, 
						unsigned int buf_len,
						unsigned int *out_len_len, 
						unsigned int *out_len );

/* get_length_length
Gets the number of bytes required to DER-encode the given length.  If 'length'
is less than 0x80, then the return value is 1.  Otherwise, the return value is
1 + the minimum number of bytes (base-256 digits) needed to express the length.

For example, if 'length' is 4096 (0x1000), then the function will return 3, 
because the DER encoded length octets are 0x82, 0x10, 0x00.
*/
unsigned int get_length_length( unsigned int length );

/* write_length_octets
Writes the DER length octets for the given length to the current position in
the FILE.  The number of octets written is get_length_length(length).
f may not be null.
*/
void write_length_octets( FILE *f, unsigned int length );

/* create_cross_cert
Creates a cross certificate pair from two files that have already been read
into buffers.  The file name of the resulting cross certificate pair 
corresponds to the file names passed in as fname1 and fname2.  The forward
certificate in the cross certificate pair is assumed to be cert1, while the
reverse cert is cert2.  The pair is written as a DER SEQUENCE (tag 0x30) that
holds cert1 under tag 0xA0 followed by cert2 under tag 0xA1.

For example, if "cert1.der" was read into cert1, and "cert2.der" was read into
cert2, this function would create the cross certificate pair named
"cert1_TO_cert2.xcert", with cert1.der as the forward (first) certificate, and
cert2.der as the reverse (second) certificate.  The extension is dropped from
each input name and each name part is limited to 20 characters.

Status and error messages are written to standard out.
*/
void create_cross_cert( char* fname1, char* fname2, 
						unsigned char *cert1, unsigned int cert1_len,
						unsigned char *cert2, unsigned int cert2_len );

/* split_cross_cert
Splits a cross certificate pair into the component certificates.  The first 
certificate in the sequence (tag 0xA0) is written to "forward.der".  The second
(tag 0xA1) is written to "reverse.der".  Both files are created relative to the
current working directory.

Status and error messages are written to standard out.
*/
void split_cross_cert( const unsigned char* cc, const unsigned int flength );




/*

IMPLEMENTATION

*/





/* reads a file to a buffer.  See function comments at the top of the file.

   On success *buf receives a malloc'd copy of the whole file (the caller must
   free it), *length receives its size in bytes, and 1 is returned.
   On failure a message is printed to standard out, *buf is set to NULL,
   *length is set to 0 and 0 is returned.  Failure cases: a NULL argument,
   the file cannot be opened, its size cannot be determined, it is empty,
   it is larger than an unsigned int can describe, allocation fails, or fewer
   bytes than expected are read. */
int read_file( const char* fname, unsigned char **buf, unsigned int *length )
{
	FILE *file = NULL;
	unsigned char *data = NULL;
	long file_size = 0;
	size_t bytes_read = 0;

	if( fname == NULL || buf == NULL || length == NULL ) 
	{ 
		printf("Invalid null parameter passed to read_file\n");
		return 0;
	}

	/* give the outputs a defined value on every failure path */
	*buf = NULL;
	*length = 0;

	/* standard fopen: fopen_s is the optional C11 Annex K API and is missing
	   from many C libraries */
	file = fopen( fname, "rb" );
	if( file == NULL ) {
		printf("File %s cannot be opened for reading\n", fname);
		return 0;
	}

	/* ftell reports failure as -1L, which must not be converted to unsigned */
	if( 0 != fseek( file, 0, SEEK_END ) || (file_size = ftell( file )) < 0 ) {
		printf( "I/O Error in file %s\n", fname );
		fclose( file );
		return 0;
	}

	if( file_size == 0 ) {
		printf( "0 Length for file %s\n", fname );
		fclose( file );
		return 0;
	}

	/* the size is reported back through an unsigned int */
	if( (unsigned long)file_size > UINT_MAX ) {
		printf( "File %s is too large\n", fname );
		fclose( file );
		return 0;
	}

	rewind( file );

	/* allocate the buffer */
	data = malloc( (size_t)file_size );
	if( data == NULL ) {
		printf( "Memory error allocating buffer for %s contents\n", fname );
		fclose( file );
		return 0;
	}

	/* read the file */
	bytes_read = fread( data, 1, (size_t)file_size, file );
	fclose( file );
	file = NULL;

	if( bytes_read != (size_t)file_size ) {
		printf( "Error reading %s: %u of %u bytes read.\n", fname,
				(unsigned int)bytes_read, (unsigned int)file_size );
		free( data );
		return 0;
	}

	*buf = data;
	*length = (unsigned int)file_size;
	return 1;
}


/* gets the length information from the 'loc' position in the buffer.

   Parses the DER length field that starts at buf[loc] ('buf' holds 'buf_len'
   bytes).  On success stores the number of octets the length field occupies
   in *out_len_len, the decoded length in *out_len, and returns 1.
   On failure prints a message, leaves both outputs at 0 and returns 0.

   Rejected inputs:
   - NULL pointers, or 'loc' outside the buffer;
   - the indefinite-length marker 0x80, which DER does not allow;
   - long forms that use more octets than fit in an unsigned int;
   - long forms whose octets would run past the end of the buffer.
   The decoded length is not checked against the buffer; callers must do so. */
int get_length_info(const unsigned char* buf, unsigned int loc, 
					unsigned int buf_len, 
					unsigned int *out_len_len, unsigned int *out_len )
{
	unsigned int num_octets = 0;
	unsigned int value = 0;
	unsigned int i = 0;

	if( out_len_len == NULL || out_len == NULL ) {
		return 0;
	}
	*out_len = 0;
	*out_len_len = 0;

	if( buf == NULL ) {
		return 0;
	}

	/* make sure 'loc' is a valid position in the buffer */
	if( loc >= buf_len ) { 
		printf("Buffer position is out of range (0x%02x >= 0x%02x)\n", loc, buf_len);
		return 0;
	}

	if( buf[loc] < 0x80 ) {
		/* short form: the octet is the length itself */
		*out_len_len = 1;
		*out_len = buf[loc];
		return 1;
	}

	/* long form: the low 7 bits give the number of length octets that follow.
	   Zero means indefinite length; more than sizeof(unsigned int) octets
	   cannot be represented and would make the shifts below undefined. */
	num_octets = buf[loc] & 0x7F;
	if( num_octets == 0 || num_octets > sizeof(unsigned int) ) {
		printf("Unsupported length encoding (0x%02x)\n", buf[loc]);
		return 0;
	}

	/* ensure that the length octets do not exceed the boundary of the buffer;
	   written as a subtraction (safe because loc < buf_len) so it cannot wrap */
	if( num_octets > buf_len - loc - 1 ) {
		printf("Invalid length encoding (length octets would exceed the buffer boundary)\n");
		return 0;
	}

	/* accumulate big-endian octets in an unsigned value; shifting the promoted
	   (signed int) octet left by 24 would overflow for octets >= 0x80 */
	for( i = 1; i <= num_octets; i++ ) {
		value = (value << 8) | buf[loc + i];
	}

	*out_len_len = num_octets + 1;
	*out_len = value;
	return 1;
}


/* Returns how many octets a DER length field for 'length' occupies:
   1 for the short form (length < 0x80); otherwise one prefix octet plus the
   minimum number of big-endian octets needed to hold 'length'. */
unsigned int get_length_length( unsigned int length )
{
	unsigned int value_octets = 0;

	if( length >= 0x80 ) {
		for( ; length != 0; length >>= 8 ) {
			value_octets++;
		}
	}
	return value_octets + 1;
}


/* Writes the DER length octets for 'length' at the current position of 'f'.
   Lengths below 0x80 use the short form: a single octet equal to the length.
   Larger lengths use the long form: (0x80 | n) followed by the n big-endian
   octets of the length, where 1 + n == get_length_length(length).
   Does nothing if 'f' is NULL.  Write errors are left for the caller to
   detect with ferror(f). */
void write_length_octets( FILE *f, unsigned int length )
{
	/* one prefix octet plus at most one octet per byte of an unsigned int;
	   a fixed buffer avoids an (unchecked) heap allocation */
	unsigned char octets[1 + sizeof(unsigned int)];
	unsigned int count = 0;
	unsigned int i = 0;

	if( f == NULL ) {
		return;
	}

	count = get_length_length( length );
	if( count == 0 || count > sizeof octets ) {
		return;
	}

	if( count == 1 ) {
		/* short form: the octet is the length itself (high bit clear) */
		octets[0] = (unsigned char)length;
	} else {
		/* long form: prefix gives the number of length octets that follow */
		octets[0] = (unsigned char)(0x80 | (count - 1));
		for( i = count - 1; i >= 1; i-- ) {
			octets[i] = (unsigned char)(length & 0xFF);
			length >>= 8;
		}
	}
	fwrite( octets, sizeof(unsigned char), count, f );
}



/* Maximum number of characters taken from each input file name when building
   the output file name. */
enum { CROSS_CERT_STEM_MAX = 20 };

/* Copies the "stem" of 'fname' into 'out' (capacity 'out_size' bytes, which
   must be > 0): the text after the last '/' or '\\' up to, but not including,
   the first '.', truncated to CROSS_CERT_STEM_MAX characters and
   NUL-terminated.  'fname' is not modified. */
static void cross_cert_file_stem( const char *fname, char *out, size_t out_size )
{
	const char *base = fname;
	const char *p = NULL;
	size_t len = 0;

	/* drop directory components so that dots or separators in a directory
	   name cannot end up in (or empty out) the generated file name */
	for( p = fname; *p != '\0'; p++ ) {
		if( *p == '/' || *p == '\\' ) {
			base = p + 1;
		}
	}

	len = strcspn( base, "." );
	if( len > CROSS_CERT_STEM_MAX ) {
		len = CROSS_CERT_STEM_MAX;
	}
	if( len >= out_size ) {
		len = out_size - 1;
	}
	memcpy( out, base, len );
	out[len] = '\0';
}

/* Creates a cross certificate pair from two certificates.

   Writes "<stem1>_TO_<stem2>.xcert" to the current working directory, where
   each stem is the input file's base name without its extension, limited to
   CROSS_CERT_STEM_MAX characters.  The file contains the DER encoding
       SEQUENCE (0x30) {
           [0] (0xA0) cert1,   -- forward certificate
           [1] (0xA1) cert2    -- reverse certificate
       }
   The caller's fname1/fname2 strings are not modified.  Does nothing if any
   pointer is NULL or either certificate is empty.  Status and error messages
   are written to standard out. */
void create_cross_cert( char* fname1, char* fname2, 
						 unsigned char *cert1, unsigned int cert1_len, 
						 unsigned char *cert2, unsigned int cert2_len ) {

	char stem1[CROSS_CERT_STEM_MAX + 1];
	char stem2[CROSS_CERT_STEM_MAX + 1];
	char output_fname[OUTPUT_FNAME_LEN];
	unsigned int forward_len = 0, reverse_len = 0, total_len = 0;
	int name_len = 0;
	int write_failed = 0;
	FILE *outfile = NULL;

	if( !fname1 || !fname2 || !cert1 || !cert2 || cert1_len == 0 || cert2_len == 0 ) {
		return;
	}

	// create the output filename from the two stems
	cross_cert_file_stem( fname1, stem1, sizeof stem1 );
	cross_cert_file_stem( fname2, stem2, sizeof stem2 );
	name_len = snprintf( output_fname, sizeof output_fname, "%s_TO_%s.xcert", stem1, stem2 );
	if( name_len < 0 || (size_t)name_len >= sizeof output_fname ) {
		printf( "Output file name is too long\n" );
		return;
	}

	// print a status message
	printf( "Creating %s......\n", output_fname );

	// each element = 1 tag octet + its length octets + the certificate
	forward_len = 1 + get_length_length( cert1_len ) + cert1_len;
	reverse_len = 1 + get_length_length( cert2_len ) + cert2_len;
	// content of the outer SEQUENCE = both elements
	total_len = forward_len + reverse_len;

	// unsigned arithmetic wraps silently; detect it instead of writing a
	// corrupt length
	if( forward_len < cert1_len || reverse_len < cert2_len || total_len < forward_len ) {
		printf( "Certificates are too large to encode\n" );
		return;
	}

	outfile = fopen( output_fname, "wb" );
	if( outfile == NULL ) {
		printf( "Couldn't open file %s\n", output_fname );
		return;
	}

	// outer SEQUENCE
	fputc( 0x30, outfile );
	write_length_octets( outfile, total_len );

	// forward cert under context-specific constructed tag [0]
	fputc( 0xA0, outfile );
	write_length_octets( outfile, cert1_len );
	fwrite( cert1, sizeof(unsigned char), cert1_len, outfile );

	// reverse cert under context-specific constructed tag [1]
	fputc( 0xA1, outfile );
	write_length_octets( outfile, cert2_len );
	fwrite( cert2, sizeof(unsigned char), cert2_len, outfile );

	// report any write error, including one surfaced by the final flush
	if( ferror( outfile ) ) {
		write_failed = 1;
	}
	if( fclose( outfile ) != 0 ) {
		write_failed = 1;
	}
	outfile = NULL;
	if( write_failed ) {
		printf( "Error writing %s\n", output_fname );
	}
}

/* Parses one element of the pair starting at cc[*loc]: checks that its tag is
   'expected_tag', decodes its length, verifies that the contents lie inside
   the buffer, and writes the contents to 'out_fname'.  'which' ("forward" or
   "reverse") is used in messages.  On success advances *loc past the element
   and returns 1; otherwise prints a message and returns 0. */
static int split_write_element( const unsigned char *cc, unsigned int flength,
								unsigned int *loc, unsigned char expected_tag,
								const char *which, const char *out_fname )
{
	unsigned int length = 0, length_of_length = 0;
	size_t written = 0;
	FILE *cert = NULL;

	if( *loc >= flength ) {
		printf("Missing %s cert (unexpected end of data)\n", which);
		return 0;
	}
	if( cc[*loc] != expected_tag ) {
		printf("Invalid tag found for %s cert: 0x%02x (expected 0x%02X)\n",
			   which, cc[*loc], expected_tag);
		return 0;
	}
	(*loc)++;

	if( !get_length_info( cc, *loc, flength, &length_of_length, &length ) ) {
		return 0;
	}
	/* get_length_info guarantees *loc + length_of_length <= flength */
	*loc += length_of_length;

	/* the encoded length comes from the input file: never trust it */
	if( length > flength - *loc ) {
		printf("Encoded length of %s cert exceeds the buffer boundary\n", which);
		return 0;
	}

	printf("Writing %s cert...\n", which);
	cert = fopen( out_fname, "wb" );
	if( cert == NULL ) {
		printf("Error opening '%s' for writing\n", out_fname);
		return 0;
	}
	written = fwrite( cc + *loc, sizeof(unsigned char), length, cert );
	if( fclose( cert ) != 0 || written != length ) {
		printf("Error writing '%s'\n", out_fname);
		return 0;
	}

	*loc += length;
	return 1;
}

/* splits a cross certificate pair into two certificates, 'forward.der' and
   'reverse.der'.

   Expects 'cc' (flength bytes) to hold SEQUENCE { 0xA0 forward, 0xA1 reverse }.
   An unexpected outer tag is reported but parsing continues; a mismatched
   outer length, a wrong element tag, or any length that would read past the
   end of the buffer stops processing.  The forward cert is written before the
   reverse element is examined, so forward.der may exist even if the reverse
   cert is invalid.  Messages are written to standard out. */
void split_cross_cert( const unsigned char* cc, const unsigned int flength )
{
	unsigned int length = 0, length_of_length = 0;
	unsigned int loc = 0;

	if( cc == NULL || flength == 0 ) { 
		return;
	}

	/* Opening tag should be 0x30 (ASN sequence) */
	if( cc[0] != 0x30 ) { 
		printf("Unexpected opening tag found (0x%02x)...expected 0x30\n", cc[0]);
	}

	if( !get_length_info( cc, 1, flength, &length_of_length, &length ) ) {
		return;
	}

	/* the outer length must cover exactly the rest of the buffer; written as a
	   subtraction (1 + length_of_length <= flength here) so it cannot wrap */
	if( length != flength - 1 - length_of_length ) {
		printf("Length of outer sequence does not match encoded length\n");
		return;
	}

	loc = 1 + length_of_length;

	if( !split_write_element( cc, flength, &loc, 0xA0, "forward", "forward.der" ) ) {
		return;
	}
	split_write_element( cc, flength, &loc, 0xA1, "reverse", "reverse.der" );
}


/* Command-line entry point.

   Usage:
     CrossCert.exe (cert1) (cert2)       writes cert1_TO_cert2.xcert and
                                         cert2_TO_cert1.xcert
     CrossCert.exe -split (cross cert)   writes forward.der and reverse.der

   Returns EXIT_SUCCESS once the input file(s) were read and processed, and
   EXIT_FAILURE on a usage error or when an input file cannot be read.
   (Failures inside create_cross_cert/split_cross_cert are reported on
   standard out only; those functions return no status.) */
int main( int argc, char **argv ) {
	unsigned char* file1_contents = NULL;
	unsigned char* file2_contents = NULL;
	/* must match read_file's 'unsigned int *' length parameter */
	unsigned int file1_len = 0;
	unsigned int file2_len = 0;
	int status = EXIT_FAILURE;

	if( argc == 3 && 0 == strcmp(argv[1], "-split") ) { 
		// split a cross certificate
		if( read_file( argv[2], &file1_contents, &file1_len ) ) {
			split_cross_cert( file1_contents, file1_len );
			status = EXIT_SUCCESS;
		}
	} else if( argc == 3 ) {
		// create cross certificates in both directions
		if( read_file( argv[1], &file1_contents, &file1_len ) && 
			read_file( argv[2], &file2_contents, &file2_len ) )
		{
			create_cross_cert( argv[1], argv[2], file1_contents, file1_len, file2_contents, file2_len );
			create_cross_cert( argv[2], argv[1], file2_contents, file2_len, file1_contents, file1_len );
			status = EXIT_SUCCESS;
		}
	} else { 		
		printf( "Proper Use: \n");
		printf( "  CrossCert.exe (cert1) (cert2)\n" );
		printf( "  CrossCert.exe -split (cross cert)\n" );
	}

	/* free unconditionally: the first buffer must be released even when the
	   second read fails (free(NULL) is a no-op) */
	free( file1_contents );
	free( file2_contents );
	file1_contents = NULL;
	file2_contents = NULL;

	return status;
}
