/*
    BFilter - a smart ad-filtering web proxy
    Copyright (C) 2002-2005  Joseph Artsimovich <joseph_a@mail.ru>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*/

#include "pch.h"

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include "AsyncConnection.h"
#include "AsyncConnectionListener.h"
#include "Reactor.h"
#include "IntrusivePtr.h"
#include "DnsCache.h"
#include <ace/INET_Addr.h>
#include <ace/OS_NS_sys_socket.h>
#include <ace/os_include/netinet/os_tcp.h> // FOR TCP_NODELAY
#include <cassert>

using namespace std;

/**
 * Creates an idle connection object.
 *
 * Initial state: no listener, status NOT_CONNECTED, a placeholder
 * address ("", -1), no reactor and a zero reference count.
 * m_handlerId and m_timerId are not in the initializer list; presumably
 * their default constructors yield the same "empty" value that
 * unregisterHandlers() resets them to -- confirm in AsyncConnection.h.
 */
AsyncConnection::AsyncConnection()
:	m_pListener(0),
	m_status(NOT_CONNECTED),
	m_addr("", -1),
	m_pReactor(0),
	m_refCounter(0)
{
}

/**
 * Tears the connection down via abort(): reactor registrations are
 * released and the socket is closed. The listener is not notified.
 */
AsyncConnection::~AsyncConnection()
{
	abort();
}

/**
 * Starts a non-blocking TCP connect to @p addr.
 *
 * Any previous connection is aborted first, then onNewConnection() is
 * called and the status becomes IN_PROGRESS. The address is resolved
 * through DnsCache, falling back to addr.resolve() and caching the result.
 *
 * Failures detected synchronously call onConnFailed() before this
 * function returns:
 *  - resolution failure:               FAIL_DNS, errnum 0
 *  - socket open / non-blocking setup: FAIL_OTHER, errno of the failed call
 *  - reactor registration exception:   FAIL_OTHER, errnum 0
 *  - connect() failing with anything other than EINPROGRESS/EWOULDBLOCK:
 *                                      FAIL_OTHER, errno of connect()
 *
 * Otherwise the socket is registered with @p reactor for all events,
 * together with a timer. The outcome is then reported asynchronously
 * through handleCompletion() or handleTimeout().
 *
 * @param reactor Reactor that will drive this connection.
 * @param addr    Destination address.
 * @param timeout Passed unchanged to Reactor::registerTimer(); how a null
 *                value is treated is defined by the Reactor (not visible here).
 */
void
AsyncConnection::initiate(Reactor& reactor,
	InetAddr const& addr, ACE_Time_Value const* timeout)
{
	int& errno_ref = errno;
	abort();
	onNewConnection();
	m_status = IN_PROGRESS;
	m_addr = addr;
	ACE_INET_Addr resolved_addr;
	if (!DnsCache::instance()->get(addr, resolved_addr)) {
		if (addr.resolve(resolved_addr)) {
			DnsCache::instance()->put(addr, resolved_addr);
		} else {
			onConnFailed(AsyncConnectionListener::FAIL_DNS, 0);
			return;
		}
	}
	
	if (m_peer.open(SOCK_STREAM, resolved_addr.get_type(), 0, 0) == -1
	    || m_peer.enable(ACE_NONBLOCK) == -1) {
		// Capture errno before close(): close() may overwrite it, and the
		// listener must see the error of the call that actually failed.
		int const err = errno_ref;
		m_peer.close();
		onConnFailed(AsyncConnectionListener::FAIL_OTHER, err);
		return;
	}
	
	m_pReactor = &reactor;
	try {
		// The IntrusivePtr takes a reference (see ref()/unref()); the
		// reactor keeps its own copies for as long as we stay registered.
		IntrusivePtr<EventHandlerBase> handler(this);
		m_handlerId = reactor.registerHandler(
			m_peer.get_handle(), handler, Reactor::ALL_EVENTS
		);
		m_timerId = reactor.registerTimer(handler, timeout);
	} catch (Reactor::Exception&) {
		abort();
		onConnFailed(AsyncConnectionListener::FAIL_OTHER, 0);
		return;
	}
	
	sockaddr* sa = reinterpret_cast<sockaddr*>(resolved_addr.get_addr());
	if (ACE_OS::connect(m_peer.get_handle(), sa, resolved_addr.get_size()) == -1) {
		// Snapshot errno now: abort() below unregisters from the reactor and
		// closes the socket, and either step may overwrite errno.
		int const err = errno_ref;
		if (err != EINPROGRESS && err != EWOULDBLOCK) {
			abort();
			onConnFailed(AsyncConnectionListener::FAIL_OTHER, err);
			return;
		}
	}
	// Connect is in progress (or already done). The registered handlers
	// report the result through handleCompletion().
}

/**
 * Cancels the current connection attempt or connection, if any.
 *
 * Steps, in order:
 *  1. Drop reactor registrations (handler and timer).
 *  2. Close the socket.
 *  3. Set the status to NOT_CONNECTED.
 *
 * The listener is NOT notified; callers that need a failure report call
 * onConnFailed() themselves. It is also called at the start of initiate()
 * and from the destructor, so presumably it must tolerate an already-idle
 * object -- unregisterHandlers() does check m_pReactor first.
 */
void
AsyncConnection::abort()
{
	// Unregister before closing so the reactor never watches a closed handle.
	unregisterHandlers();
	m_peer.close();
	m_status = NOT_CONNECTED;
}

/**
 * Best-effort: sets TCP_NODELAY (disables Nagle's algorithm) on @p strm.
 *
 * The result of set_option() is deliberately ignored. On platforms where
 * ACE_LACKS_TCPNODELAY is defined this function does nothing.
 */
void
AsyncConnection::setTcpNoDelay(ACE_SOCK_Stream& strm)
{
#ifndef ACE_LACKS_TCPNODELAY
	int nodelay = 1; // non-const: set_option() takes a plain void*
	strm.set_option(
		ACE_IPPROTO_TCP, TCP_NODELAY,
		&nodelay, static_cast<int>(sizeof nodelay)
	);
#endif
}

/**
 * Reactor read-readiness callback. While connecting, any readiness is
 * treated as connect completion; handleCompletion() reads SO_ERROR to
 * determine the actual outcome.
 */
void
AsyncConnection::handleRead(ACE_HANDLE)
{
	handleCompletion();
}

/**
 * Reactor write-readiness callback. Writability of a connecting socket
 * signals that the connect has finished; handleCompletion() determines
 * whether it succeeded.
 */
void
AsyncConnection::handleWrite(ACE_HANDLE)
{
	handleCompletion();
}

/**
 * Reactor exceptional-condition callback. It is handled like the other
 * readiness events: the outcome comes from SO_ERROR in handleCompletion().
 */
void
AsyncConnection::handleExcept(ACE_HANDLE)
{
	handleCompletion();
}

/**
 * Reactor timer callback: the connect did not finish within the timeout
 * given to initiate(). Tears everything down and reports FAIL_TIMEOUT
 * with ETIMEDOUT.
 */
void
AsyncConnection::handleTimeout(ReactorTimerId const&)
{
	// Tear down first so the listener sees a fully idle object.
	abort();
	onConnFailed(AsyncConnectionListener::FAIL_TIMEOUT, ETIMEDOUT);
}

/**
 * Adds one reference. Called when an IntrusivePtr<EventHandlerBase>
 * to this object is created (for example in initiate(), and presumably
 * by the reactor when it stores copies of that pointer).
 */
void
AsyncConnection::ref()
{
	++m_refCounter;
}

/**
 * Drops one reference.
 *
 * When the count reaches zero, nothing holds an IntrusivePtr to this
 * handler anymore. Presumably that means the reactor has released both
 * registrations, so the stored ids and the reactor pointer are cleared.
 * This keeps unregisterHandlers() from later touching registrations
 * that no longer exist.
 *
 * The object is never deleted here; its lifetime is managed elsewhere.
 * Always returns true. NOTE(review): the meaning of the return value is
 * defined by the IntrusivePtr / EventHandlerBase contract, which is not
 * visible in this file -- confirm.
 */
bool
AsyncConnection::unref()
{
	if (--m_refCounter == 0) {
		m_handlerId = ReactorHandlerId();
		m_timerId = ReactorTimerId();
		m_pReactor = 0;
	}
	return true;
}

/**
 * Removes this object's I/O handler and timer from the reactor, if registered,
 * and forgets the reactor.
 *
 * Does nothing when no reactor is attached. The handler is removed before the
 * timer. NOTE(review): unregistering may drop the reactor's last reference
 * and run unref(), which clears m_timerId and m_pReactor. In that case the
 * timer branch below is skipped because m_timerId is already empty.
 */
void
AsyncConnection::unregisterHandlers()
{
	if (!m_pReactor) {
		return;
	}
	
	if (m_handlerId) {
		m_pReactor->unregisterHandler(m_handlerId);
		m_handlerId = ReactorHandlerId();
	}
	
	if (m_timerId) {
		m_pReactor->unregisterTimer(m_timerId);
		m_timerId = ReactorTimerId();
	}
	
	m_pReactor = 0;
}

void
AsyncConnection::handleCompletion()
{
	typedef AsyncConnectionListener ASL;
	ASL::FailureCode fcode = ASL::FAIL_OTHER;
	int err = 0;
	int errlen = sizeof(err);
	if (m_peer.get_option(SOL_SOCKET, SO_ERROR, &err, &errlen) == 0) {
		switch (err) {
			case 0:
				unregisterHandlers();
				onConnEstablished();
				return;
			case ETIMEDOUT:
				fcode = ASL::FAIL_TIMEOUT;
				break;
			case ECONNREFUSED:
				fcode = ASL::FAIL_REFUSED;
				break;
			case ENETUNREACH:
			case EHOSTUNREACH:
			case EHOSTDOWN:
				fcode = ASL::FAIL_UNREACHABLE;
				break;
		}
	} else {
		err = errno;
	}
	abort();
	onConnFailed(fcode, err);
}

/**
 * Marks the connection ESTABLISHED and, if a listener is attached,
 * calls its connectionEstablished().
 */
void
AsyncConnection::onConnEstablished()
{
	m_status = ESTABLISHED;
	if (!m_pListener) {
		return;
	}
	m_pListener->connectionEstablished();
}

/**
 * Marks the connection NOT_CONNECTED and, if a listener is attached,
 * forwards the failure category and errno value to it.
 *
 * @param fcode  Failure category reported to the listener.
 * @param errnum errno-style code (0 when no system error applies).
 */
void
AsyncConnection::onConnFailed(AsyncConnectionListener::FailureCode fcode, int errnum)
{
	m_status = NOT_CONNECTED;
	if (!m_pListener) {
		return;
	}
	m_pListener->connectionFailed(fcode, errnum);
}
