/*
 * WinPR: Windows Portable Runtime
 * Stream Utils
 *
 * Copyright 2011 Vic Lee
 * Copyright 2012 Marc-Andre Moreau <marcandre.moreau@gmail.com>
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#include <assert.h>
#include <stdint.h>
#include <winpr/crt.h>
#include <winpr/stream.h>

/**
 * Grow the stream buffer so that it can hold at least `size` bytes.
 *
 * The new capacity is `size` rounded up past the next multiple of 128 bytes
 * (a `size` that is already a multiple of 128 still gains one extra 128-byte
 * step). Bytes between the old and the new capacity are zero-filled, the
 * stream length is set to the new capacity and the current position is
 * preserved.
 *
 * If the stream does not own its buffer (isOwner == FALSE), a fresh buffer is
 * allocated, the old contents are copied into it and the stream takes
 * ownership of the copy; the borrowed buffer itself is left untouched.
 *
 * @param s    stream to grow, must not be NULL
 * @param size minimum required capacity in bytes
 * @return TRUE if the capacity is at least `size` on return; FALSE on
 *         allocation failure or when the rounded-up capacity would overflow
 *         size_t (the stream is left unchanged in both failure cases).
 */
BOOL Stream_EnsureCapacity(wStream* s, size_t size)
{
	WINPR_ASSERT(s);
	if (s->capacity >= size)
		return TRUE;

	const size_t increment = 128ull;
	const size_t old_capacity = s->capacity;

	/* size + increment wraps around for huge requests. A wrapped (small)
	 * capacity would make the ZeroMemory length below underflow and write far
	 * out of bounds, so refuse such requests instead. */
	if (size > SIZE_MAX - increment)
		return FALSE;

	const size_t new_capacity = size + increment - size % increment;
	const size_t position = Stream_GetPosition(s);

	BYTE* new_buf = NULL;
	if (!s->isOwner)
	{
		new_buf = (BYTE*)malloc(new_capacity);
		if (!new_buf)
			return FALSE;

		/* Copying from a NULL source is undefined even for zero bytes, and a
		 * borrowed zero-capacity buffer has nothing to copy anyway. */
		if (old_capacity > 0)
			CopyMemory(new_buf, s->buffer, old_capacity);
		s->isOwner = TRUE;
	}
	else
	{
		/* realloc leaves the old buffer intact on failure, so the stream
		 * stays valid when we bail out here. */
		new_buf = (BYTE*)realloc(s->buffer, new_capacity);
		if (!new_buf)
			return FALSE;
	}

	s->buffer = new_buf;
	s->capacity = new_capacity;
	s->length = new_capacity;
	/* Zero the newly added tail so it never exposes stale heap contents. */
	ZeroMemory(&s->buffer[old_capacity], new_capacity - old_capacity);

	/* The buffer may have moved; re-anchor the read/write pointer. */
	Stream_SetPosition(s, position);
	return TRUE;
}

/**
 * Make sure at least `size` bytes can be written starting at the current
 * stream position.
 *
 * When the remaining space is insufficient, the buffer is grown through
 * Stream_EnsureCapacity to (current capacity + size). That is always enough,
 * because the position never exceeds the capacity.
 *
 * @param s    stream to check/grow, must not be NULL
 * @param size number of bytes that must be available after the position
 * @return TRUE if the space is available on return; FALSE if growing failed
 *         or the requested growth would overflow size_t.
 */
BOOL Stream_EnsureRemainingCapacity(wStream* s, size_t size)
{
	WINPR_ASSERT(s);

	const size_t capacity = Stream_Capacity(s);
	const size_t position = Stream_GetPosition(s);
	WINPR_ASSERT(position <= capacity);

	/* Compare against the remaining space by subtraction: position + size
	 * could wrap around for huge sizes and falsely report enough room. */
	if (size <= capacity - position)
		return TRUE;

	/* capacity + size would wrap; no such buffer can exist. */
	if (size > SIZE_MAX - capacity)
		return FALSE;

	return Stream_EnsureCapacity(s, capacity + size);
}

/*
 * Allocate a new heap-backed stream.
 *
 * If `buffer` is non-NULL it is adopted as the backing store. The stream is
 * marked as its owner, so Stream_Free(s, TRUE) will free() it. Otherwise a
 * fresh, uninitialised `size`-byte buffer is malloc'ed.
 * Length and capacity are both set to `size` and the position starts at 0.
 *
 * Returns NULL when both `buffer` and `size` are empty, or on allocation
 * failure.
 */
wStream* Stream_New(BYTE* buffer, size_t size)
{
	if ((buffer == NULL) && (size == 0))
		return NULL;

	wStream* stream = malloc(sizeof(wStream));
	if (stream == NULL)
		return NULL;

	BYTE* data = (buffer != NULL) ? buffer : (BYTE*)malloc(size);
	if (data == NULL)
	{
		free(stream);
		return NULL;
	}

	stream->buffer = data;
	stream->pointer = data;
	stream->length = size;
	stream->capacity = size;

	stream->count = 0;
	stream->pool = NULL;
	stream->isOwner = TRUE;
	stream->isAllocatedStream = TRUE;
	return stream;
}

/**
 * Initialise a caller-provided wStream as a view over an existing buffer.
 *
 * The stream does not own `buffer`. Stream_Free never frees it, and
 * Stream_EnsureCapacity copies it into a new allocation before growing.
 * The stream is not marked as heap-allocated either, so Stream_Free will not
 * free `s` itself. The position starts at the beginning of `buffer`; length
 * and capacity are both `size`.
 *
 * @param s      stream object to initialise, must not be NULL
 * @param buffer backing storage, must not be NULL
 * @param size   size of `buffer` in bytes
 */
void Stream_StaticInit(wStream* s, BYTE* buffer, size_t size)
{
	/* WINPR_ASSERT for consistency with the rest of this file. */
	WINPR_ASSERT(s);
	WINPR_ASSERT(buffer);

	s->buffer = s->pointer = buffer;
	s->capacity = s->length = size;
	s->pool = NULL;
	s->count = 0;
	s->isAllocatedStream = FALSE;
	s->isOwner = FALSE;
}

/*
 * Release a stream.
 *
 * The backing buffer is free()d only when the caller requests it
 * (bFreeBuffer) AND the stream owns it (isOwner). The wStream structure
 * itself is free()d only when it is marked as heap-allocated
 * (isAllocatedStream, as set by Stream_New). A NULL stream is ignored.
 */
void Stream_Free(wStream* s, BOOL bFreeBuffer)
{
	if (!s)
		return;

	const BOOL releaseBuffer = bFreeBuffer && s->isOwner;
	if (releaseBuffer)
		free(s->buffer);

	if (s->isAllocatedStream)
		free(s);
}

/*
 * Check that at least `len` bytes remain in `s`.
 *
 * On a shortfall, the printf-style prefix (fmt, ...) is forwarded to
 * Stream_CheckAndLogRequiredLengthExVa, which logs through the logger named
 * `tag` at `level`. Returns TRUE if enough data remains, FALSE otherwise.
 */
BOOL Stream_CheckAndLogRequiredLengthEx(const char* tag, DWORD level, wStream* s, UINT64 len,
                                        const char* fmt, ...)
{
	const size_t remaining = Stream_GetRemainingLength(s);
	if (remaining >= len)
		return TRUE;

	va_list ap;
	va_start(ap, fmt);
	Stream_CheckAndLogRequiredLengthExVa(tag, level, s, len, fmt, ap);
	va_end(ap);

	return FALSE;
}

/*
 * va_list variant of Stream_CheckAndLogRequiredLengthEx.
 *
 * Returns TRUE if at least `len` bytes remain in `s`. Otherwise it looks up
 * the logger named `tag` and delegates to
 * Stream_CheckAndLogRequiredLengthWLogExVa, returning its result.
 */
BOOL Stream_CheckAndLogRequiredLengthExVa(const char* tag, DWORD level, wStream* s, UINT64 len,
                                          const char* fmt, va_list args)
{
	const size_t available = Stream_GetRemainingLength(s);
	if (available >= len)
		return TRUE;

	/* The logger is only needed on the failure path, so resolve it here. */
	wLog* const logger = WLog_Get(tag);
	return Stream_CheckAndLogRequiredLengthWLogExVa(logger, level, s, len, fmt, args);
}

/*
 * Check that at least `len` bytes remain in `s`.
 *
 * On a shortfall, the printf-style prefix (fmt, ...) is forwarded to
 * Stream_CheckAndLogRequiredLengthWLogExVa, which logs to `log` at `level`.
 * Returns TRUE if enough data remains, FALSE otherwise.
 */
BOOL Stream_CheckAndLogRequiredLengthWLogEx(wLog* log, DWORD level, wStream* s, UINT64 len,
                                            const char* fmt, ...)
{
	const size_t remaining = Stream_GetRemainingLength(s);
	if (remaining >= len)
		return TRUE;

	va_list ap;
	va_start(ap, fmt);
	Stream_CheckAndLogRequiredLengthWLogExVa(log, level, s, len, fmt, ap);
	va_end(ap);

	return FALSE;
}

/*
 * Core length check used by the other Stream_CheckAndLogRequiredLength*
 * helpers.
 *
 * Returns TRUE if at least `len` bytes remain in `s`. Otherwise it formats
 * the caller's prefix (fmt/args) into a 1024-byte local buffer, logs the
 * actual and required lengths to `log` at `level`, emits a backtrace of up
 * to 20 frames, and returns FALSE.
 */
BOOL Stream_CheckAndLogRequiredLengthWLogExVa(wLog* log, DWORD level, wStream* s, UINT64 len,
                                              const char* fmt, va_list args)
{
	const size_t available = Stream_GetRemainingLength(s);
	if (available >= len)
		return TRUE;

	/* vsnprintf truncates prefixes longer than the buffer. */
	char context[1024] = { 0 };
	vsnprintf(context, sizeof(context), fmt, args);

	WLog_Print(log, level, "[%s] invalid length, got %" PRIuz ", require at least %" PRIu64,
	           context, available, len);
	winpr_log_backtrace_ex(log, level, 20);
	return FALSE;
}
