/*
 * Copyright (C) 2019 Sean Parkinson, wolfSSL Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

#include "wolfssl_common.h"

#ifndef NO_DH

#include <wolfssl/wolfcrypt/dh.h>

#include "wolfssl_diffie_hellman.h"
#include "wolfssl_util.h"

#include <utils/debug.h>

/**
 * Private state of a wolfSSL-backed Diffie-Hellman key exchange object.
 */
typedef struct private_wolfssl_diffie_hellman_t private_wolfssl_diffie_hellman_t;

/**
 * Private data of a wolfssl_diffie_hellman_t object.
 */
struct private_wolfssl_diffie_hellman_t {

	/**
	 * Public wolfssl_diffie_hellman_t interface.
	 *
	 * NOTE(review): presumably must stay the first member, since the object
	 * is handed out as &this->public while the interface methods operate on
	 * the private struct - confirm against the METHOD() macro.
	 */
	wolfssl_diffie_hellman_t public;

	/**
	 * Diffie-Hellman group (key exchange method) implemented by this object.
	 */
	key_exchange_method_t group;

	/**
	 * wolfSSL Diffie-Hellman key, configured with the group's prime p and
	 * generator g.
	 */
	DhKey dh;

	/**
	 * Length of the prime p in bytes. Public values and the shared secret
	 * are zero-padded on the left to this length when returned.
	 */
	int len;

	/**
	 * Our private value (exponent).
	 */
	chunk_t priv;

	/**
	 * Our public value. Its length may be shorter than len, and it is padded
	 * when retrieved via get_public_key().
	 */
	chunk_t pub;

	/**
	 * Public value received from the peer (validated in set_public_key()).
	 */
	chunk_t other;

	/**
	 * Cached shared secret. It is computed lazily; an empty chunk means it
	 * has not been computed yet.
	 */
	chunk_t shared_secret;
};

METHOD(key_exchange_t, get_public_key, bool,
	private_wolfssl_diffie_hellman_t *this, chunk_t *value)
{
	chunk_t padded;

	/* return a copy of our public value, left-padded with zeros to the full
	 * length of the prime */
	padded = chunk_alloc(this->len);
	*value = chunk_copy_pad(padded, this->pub, 0x00);
	return TRUE;
}

/**
 * Compute (once) and return the Diffie-Hellman shared secret.
 *
 * The secret is calculated on first use from our private value and the peer's
 * public value, then cached in this->shared_secret. The returned copy is
 * left-padded with zeros to the length of the prime.
 *
 * @param secret	receives an allocated copy of the shared secret
 * @return			TRUE on success, FALSE if the computation failed
 */
METHOD(key_exchange_t, get_shared_secret, bool,
	private_wolfssl_diffie_hellman_t *this, chunk_t *secret)
{
	word32 len;
	int ret;

	if (!this->shared_secret.len)
	{
		this->shared_secret = chunk_alloc(this->len);
		/* len is an in/out argument of wc_DhAgree(): pass the size of the
		 * output buffer (as set_seed() does), receive the secret's length */
		len = this->shared_secret.len;
		PRIVATE_KEY_UNLOCK();
		ret = wc_DhAgree(&this->dh, this->shared_secret.ptr, &len,
						 this->priv.ptr, this->priv.len, this->other.ptr,
						 this->other.len);
		PRIVATE_KEY_LOCK();
		if (ret != 0)
		{
			DBG1(DBG_LIB, "DH shared secret computation failed");
			/* the buffer might hold partial secret data, wipe it */
			chunk_clear(&this->shared_secret);
			return FALSE;
		}
		this->shared_secret.len = len;
	}
	*secret = chunk_copy_pad(chunk_alloc(this->len), this->shared_secret, 0x00);
	return TRUE;
}

METHOD(key_exchange_t, set_public_key, bool,
	private_wolfssl_diffie_hellman_t *this, chunk_t value)
{
	/* generic checks for the group first, then wolfSSL's range check */
	if (key_exchange_verify_pubkey(this->group, value))
	{
		if (wc_DhCheckPubKey(&this->dh, value.ptr, value.len) == 0)
		{
			/* replace any previously stored peer value */
			chunk_clear(&this->other);
			this->other = chunk_clone(value);
			return TRUE;
		}
		DBG1(DBG_LIB, "DH public key value invalid");
	}
	return FALSE;
}

#ifdef TESTABLE_KE

/**
 * Replace the private value with the given one (for tests) and recompute
 * the matching public value g^priv mod p.
 *
 * @param value		new private value (cloned)
 * @param drbg		unused by this implementation
 * @return			TRUE if the public value was recomputed successfully
 */
METHOD(key_exchange_t, set_seed, bool,
	private_wolfssl_diffie_hellman_t *this, chunk_t value, drbg_t *drbg)
{
	chunk_t g = chunk_empty;
	bool success = FALSE;
	word32 len;
	int ret;

	chunk_clear(&this->priv);
	this->priv = chunk_clone(value);

	/* calculate public value - g^priv mod p */
	if (wolfssl_mp2chunk(&this->dh.g, &g))
	{
		/* the public value buffer was allocated with the full length of the
		 * prime, but pub.len may have been reduced to the size of a previous,
		 * shorter value, so pass the actual buffer size */
		len = this->len;
		PRIVATE_KEY_UNLOCK();
		ret = wc_DhAgree(&this->dh, this->pub.ptr, &len, this->priv.ptr,
						 this->priv.len, g.ptr, g.len);
		PRIVATE_KEY_LOCK();
		if (ret == 0)
		{
			this->pub.len = len;
			success = TRUE;
		}
	}

	/* g is initialized empty, so this is safe even if the conversion failed */
	chunk_free(&g);
	return success;
}

#endif /* TESTABLE_KE */

METHOD(key_exchange_t, get_method, key_exchange_method_t,
	private_wolfssl_diffie_hellman_t *this)
{
	/* the group this object was created for */
	key_exchange_method_t method = this->group;

	return method;
}

METHOD(key_exchange_t, destroy, void,
	private_wolfssl_diffie_hellman_t *this)
{
	/* release all values via chunk_clear(), then the wolfSSL key and the
	 * object itself */
	chunk_clear(&this->shared_secret);
	chunk_clear(&this->other);
	chunk_clear(&this->priv);
	chunk_clear(&this->pub);
	wc_FreeDhKey(&this->dh);
	free(this);
}

/**
 * Determine the buffer size to allocate for a generated private key.
 *
 * Primes up to 1024 bytes map to fixed sizes; for larger primes the size
 * is len / 20.
 *
 * @param len		length of the prime in bytes
 * @return			maximum private key length in bytes
 */
static int wolfssl_priv_key_size(int len)
{
	static const struct {
		int max_len;
		int priv_len;
	} limits[] = {
		{  128, 21 },
		{  256, 29 },
		{  384, 34 },
		{  512, 39 },
		{  640, 42 },
		{  768, 46 },
		{  896, 49 },
		{ 1024, 52 },
	};
	size_t i;

	for (i = 0; i < sizeof(limits) / sizeof(limits[0]); i++)
	{
		if (len <= limits[i].max_len)
		{
			return limits[i].priv_len;
		}
	}
	return len / 20;
}

/**
 * Generic internal constructor
 */
static wolfssl_diffie_hellman_t *create_generic(key_exchange_method_t group,
												chunk_t g, chunk_t p)
{
	private_wolfssl_diffie_hellman_t *this;
	word32 privLen, pubLen;
	WC_RNG rng;
	int ret;

	INIT(this,
		.public = {
			.ke = {
				.get_shared_secret = _get_shared_secret,
				.set_public_key = _set_public_key,
				.get_public_key = _get_public_key,
				.get_method = _get_method,
				.destroy = _destroy,
			},
		},
		.group = group,
		.len = p.len,
	);

#ifdef TESTABLE_KE
	this->public.ke.set_seed = _set_seed;
#endif

	if (wc_InitDhKey(&this->dh) != 0)
	{
		free(this);
		return NULL;
	}

	if (wc_DhSetKey(&this->dh, p.ptr, p.len, g.ptr, g.len) != 0)
	{
		destroy(this);
		return NULL;
	}

	if (wc_InitRng(&rng) != 0)
	{
		destroy(this);
		return NULL;
	}

	this->priv = chunk_alloc(wolfssl_priv_key_size(this->len));
	this->pub = chunk_alloc(this->len);
	privLen = this->priv.len;
	pubLen = this->pub.len;
	/* generate my public and private values */
	PRIVATE_KEY_UNLOCK();
	ret = wc_DhGenerateKeyPair(&this->dh, &rng, this->priv.ptr, &privLen,
							   this->pub.ptr, &pubLen);
	PRIVATE_KEY_LOCK();
	if (ret != 0)
	{
		wc_FreeRng(&rng);
		destroy(this);
		return NULL;
	}
	this->pub.len = pubLen;
	this->priv.len = privLen;
	wc_FreeRng(&rng);

	return &this->public;
}

/*
 * Described in header
 */
wolfssl_diffie_hellman_t *wolfssl_diffie_hellman_create(
											key_exchange_method_t group, ...)
{
	diffie_hellman_params_t *params;
	chunk_t g, p;

	if (group != MODP_CUSTOM)
	{
		/* well-known group, look up its parameters */
		params = diffie_hellman_get_params(group);
		if (params == NULL)
		{
			return NULL;
		}
		/* wolfSSL doesn't support optimized exponent sizes according to RFC 3526 */
		return create_generic(group, params->generator, params->prime);
	}

	/* custom group, generator and prime are passed as variadic arguments */
	VA_ARGS_GET(group, g, p);
	return create_generic(group, g, p);
}

#endif /* NO_DH */
