/* $Id$ */

/*
 *
 * Copyright (C) 2005 David Mazieres (dm@uun.org)
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2, or (at
 * your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 *
 */

#include "async_ssl.h"

#undef output

/*
 * Construct an aiossl on descriptor fd with a read buffer of rbufsize
 * bytes (both forwarded to the aios base).  No SSL object exists yet,
 * so the object behaves like a plain aios until startssl is called.
 */
aiossl::aiossl (int fd, size_t rbufsize)
  : aios (fd, rbufsize),
    ssl (NULL),
    bss (NULL),
    inbss (NULL),
    need_handshake (true),
    iocblock (false),
    finlock (false)
{
  // Nothing else to do; all SSL state is created lazily in startssl.
}

/*
 * Release the SSL object, if startssl ever created one.  The BIO
 * pointers (bss, inbss) are not freed separately here.
 */
aiossl::~aiossl ()
{
  if (ssl != NULL) {
    SSL_free (ssl);
  }
}

/*
 * Decide whether the object can be destroyed now or must linger long
 * enough to flush pending output and send the SSL shutdown.
 *
 *  - While finlock is set (we are inside the iocb call below) do nothing.
 *  - Without SSL, defer entirely to aios::finalize.
 *  - If the descriptor is gone, the shutdown was already sent, or a
 *    non-timeout error occurred, delete immediately.
 *  - On a timeout, make one last output attempt and delete.
 *  - Otherwise mark write-EOF, make sure a timeout is set, run iocb
 *    once, and delete only if that finished or failed the connection.
 */
void
aiossl::finalize ()
{
  if (finlock)
    return;

  if (!ssl) {
    aios::finalize ();
    return;
  }

  if (fd < 0 || sent_shutdown || (err && err != ETIMEDOUT)) {
    delete this;
    return;
  }

  // From here on err is either 0 or ETIMEDOUT; both paths signal EOF
  // on the write side first.
  weof = true;

  if (err == ETIMEDOUT) {
    dooutput ();
    delete this;
    return;
  }

  if (!gettimeout ())
    settimeout (60);

  // iocb may cause finalize to be re-entered; finlock turns such a
  // nested call into a no-op so the decision is made exactly once below.
  finlock = true;
  iocb ();
  finlock = false;

  if (fd < 0 || sent_shutdown || err)
    delete this;
}

int
aiossl::doerr (int err, int silent)
{
  switch (int err2 = SSL_get_error (ssl, err)) {
  case SSL_ERROR_NONE:
    return err;
  case SSL_ERROR_ZERO_RETURN:
    errno = EPIPE;
    return 0;
  case SSL_ERROR_WANT_READ:
    if (inbss) {
      assert (!BIO_ctrl (inbss, BIO_CTRL_INFO, 0, NULL));
      inbss = NULL;
      bss = BIO_new_socket (fd, BIO_NOCLOSE);
      SSL_set_bio (ssl, bss, bss);
    }
    if (!rcbset) {
      rcbset = true;
      fdcb (fd, selread, wrap (this, &aiossl::iocb));
    }
    errno = EAGAIN;
    return -1;
  case SSL_ERROR_WANT_WRITE:
    if (!wcbset) {
      wcbset = true;
      fdcb (fd, selwrite, wrap (this, &aiossl::iocb));
    }
    errno = EAGAIN;
    return -1;
  default:
    if (!silent)
      warn << "SSL error: " << ssl_err (err2) << "\n";
    errno = EIO;
    return -1;
  }
}

/*
 * Read as much as possible into the free space of inb.
 *
 * Without SSL this defers to aios::doinput.  With SSL, each free iovec
 * of inb is filled with SSL_read in turn until one is only partially
 * filled or SSL_read fails.
 *
 * Returns the number of bytes read if any were read; otherwise the
 * result of doerr for the failing SSL_read.
 */
int
aiossl::doinput ()
{
  if (!ssl)
    return aios::doinput ();

  int total = 0;
  const iovec *vec = inb.iniov ();

  for (int left = inb.iniovcnt (); left > 0; left--, vec++) {
    const int got = SSL_read (ssl, vec->iov_base, vec->iov_len);
    const bool filled = got >= implicit_cast<ssize_t> (vec->iov_len);

    if (!filled && got <= 0) {
      // Nothing read this round: let doerr classify the failure, but
      // prefer reporting bytes already collected in earlier rounds.
      const int rc = doerr (got);
      return total > 0 ? total : rc;
    }

    if (filled)
      assert (implicit_cast<size_t> (got) == vec->iov_len);

    inb.addbytes (got);
    total += got;

    // A short read ends this call; only a completely filled iovec
    // moves on to the next one.
    if (!filled)
      return total;
  }
  return total;
}

/*
 * Push pending output toward the peer.
 *
 * Without SSL this defers to aios::dooutput.  With SSL:
 *   1. Any data left in oldout (output queued before startssl) is
 *      written directly to the descriptor; nothing else happens until
 *      oldout is fully drained.
 *   2. The contents of outb are written through SSL_write, one iovec at
 *      a time.
 *   3. Once weof is set and outb is empty, SSL_shutdown is issued once.
 *
 * Returns 1 if SSL_write made progress, the result of the oldout write
 * while oldout is pending, the result of doerr if the first SSL_write
 * failed, -1 with errno = EAGAIN if the shutdown must be retried, and
 * otherwise 0.
 */
int
aiossl::dooutput ()
{
  if (!ssl)
    return aios::dooutput ();

  int ret = 0;

  if (oldout) {
    ret = oldout->output (fd);
    if (!oldout->resid ())
      oldout = NULL;
    if (ret < 0 || oldout)
      return ret;
  }

  while (outb.tosuio ()->resid ()) {
    const iovec *iov = outb.tosuio ()->iov ();
    int n = SSL_write (ssl, iov->iov_base, iov->iov_len);
    if (n > 0) {
      ret = 1;
      outb.tosuio ()->rembytes (n);
    }
    else {
      n = doerr (n);		// May arm the read or write callback
      return ret ? ret : n;
    }
  }

  if (weof && !outb.tosuio ()->resid () && !sent_shutdown) {
    int n = SSL_shutdown (ssl);
    // SSL_shutdown returns 0 when our close_notify has been sent but
    // the peer's has not arrived yet, and 1 when the shutdown is
    // complete.  Neither is an error, and OpenSSL documents that
    // SSL_get_error should not be consulted for a 0 return because it
    // may misleadingly report one.  Only negative returns go to doerr.
    if (n < 0) {
      n = doerr (n, true);
      if (n < 0 && errno == EAGAIN)
	return n;
    }
    sent_shutdown = true;
  }
  return ret;
}

/*
 * Build a str from a heap-allocated C string and release the original
 * buffer with free().
 *
 * NOTE(review): the callers in this file pass buffers returned by
 * X509_NAME_oneline; OpenSSL may expect those to be released with
 * OPENSSL_free rather than free -- confirm.
 */
inline str
mkstr (char *name)
{
  str copied (name);
  free (name);
  return copied;
}
/*
 * Check the peer certificate and record its identity.
 *
 * Returns false if SSL's verification result is not X509_V_OK or the
 * peer presented no certificate.  Otherwise fills in subject_dn /
 * issuer_dn (one-line distinguished names) and subject / issuer (the
 * commonName entries, or empty if absent) and returns true.
 *
 * NOTE(review): a commonName containing an embedded NUL byte is
 * silently truncated when copied out of buf -- consider comparing the
 * length returned by X509_NAME_get_text_by_NID against strlen (buf).
 */
bool
aiossl::verify ()
{
  if (SSL_get_verify_result (ssl) != X509_V_OK)
    return false;
  X509 *cert = SSL_get_peer_certificate (ssl);
  if (!cert)
    return false;

  char buf[257];

  X509_NAME *name = X509_get_subject_name (cert);
  subject_dn = mkstr (X509_NAME_oneline (name, NULL, 0));
  // X509_NAME_get_text_by_NID leaves buf untouched when the name has no
  // such entry, so start from an empty string before every lookup.
  buf[0] = '\0';
  X509_NAME_get_text_by_NID (name, NID_commonName, buf, sizeof (buf) - 1);
  buf[sizeof (buf) - 1] = '\0';
  subject = buf;

  name = X509_get_issuer_name (cert);
  issuer_dn = mkstr (X509_NAME_oneline (name, NULL, 0));
  // Without this reset an issuer lacking a commonName would be reported
  // with the subject's commonName left over from the lookup above.
  buf[0] = '\0';
  X509_NAME_get_text_by_NID (name, NID_commonName, buf, sizeof (buf) - 1);
  buf[sizeof (buf) - 1] = '\0';
  issuer = buf;

  // SSL_get_peer_certificate handed us our own reference; drop it.
  X509_free (cert);
  return true;
}

/*
 * Central I/O driver for the SSL connection.  It is installed as the
 * fd read/write callback (here and in doerr) and is also called
 * directly by finalize, setincb, setoutcb and startssl.
 *
 * In order, it:
 *   1. disarms any read/write callbacks currently set,
 *   2. flushes oldout (output queued before SSL started), returning
 *      early while that is still pending,
 *   3. performs the SSL handshake if still needed, returning early when
 *      the handshake would block or fails,
 *   4. runs input () / output () for as long as there is reason to.
 */
void
aiossl::iocb ()
{
  // Reentrancy guard: a nested call only records that another pass of
  // the input/output loop at the bottom is wanted.
  if (iocblock) {
    again = true;
    return;
  }
  iocblock = true;
  ref<aios> hold = mkref (this); // Don't let this be freed under us

  assert (fd >= 0);

  // Disarm both callbacks; doerr (or the oldout path below) re-arms
  // whichever direction is needed.
  if (rcbset) {
    fdcb (fd, selread, NULL);
    rcbset = false;
  }
  if (wcbset) {
    fdcb (fd, selwrite, NULL);
    wcbset = false;
  }

  // Pre-SSL output must be fully written before any SSL traffic.
  if (oldout) {
    int err = dooutput ();
    if (oldout) {
      // Still not drained: fail on error, otherwise wait for the
      // descriptor to become writable again.
      iocblock = false;
      if (err < 0)
	fail (errno);
      else if (!wcbset) {
	wcbset = true;
	fdcb (fd, selwrite, wrap (this, &aiossl::iocb));
      }
      return;
    }
  }

  if (need_handshake) {
    int err;
    if (server)
      err = SSL_accept (ssl);
    else
      err = SSL_connect (ssl);
    // doerr arms the appropriate callback and sets errno = EAGAIN when
    // the handshake merely needs more I/O.
    err = doerr (err);
    if (err <= 0) {
      if (errno != EAGAIN)
	fail (errno);
      iocblock = false;
      return;
    }
    else {
      // Handshake complete: record the cipher and, if the peer
      // certificate verifies, report it and invoke vcb.
      need_handshake = false;
      cipher = SSL_get_cipher_name (ssl);
      if (debugname)
	warnx << debugname << errpref << "SSL-cipher " << cipher << "\n";
      if (verify ()) {
	if (debugname) {
	  warnx << debugname << errpref << "SSL-cert issuer=" << issuer
	       << ", subject=" << subject << "\n";
	  //warnx << debugname << errpref << "Subject: " << subject_dn << "\n";
	  //warnx << debugname << errpref << "Issuer: " << issuer_dn << "\n";
	}
	if (vcb)
	  (*vcb) ();
      }
    }
  }

  // Run input and output; repeat whenever a nested iocb call (see the
  // guard at the top) asked for another pass.  fd is re-checked before
  // each step.
  do {
    again = false;
    if (fd >= 0 && reading ())
      input ();
    if (fd >= 0 && (writing () || (weof && !sent_shutdown) || oldout))
      output ();
  } while (again);

  iocblock = false;
}

/*
 * Request input processing.  Without SSL this defers to
 * aios::setincb.  With SSL, iocb is run right away unless a read
 * callback is already armed on the descriptor.
 */
void
aiossl::setincb ()
{
  if (ssl) {
    if (!rcbset)
      iocb ();
  }
  else
    aios::setincb ();
}

/*
 * Request output processing.  Without SSL this defers to
 * aios::setoutcb.  With SSL, iocb is run right away unless a write
 * callback is already armed on the descriptor.
 */
void
aiossl::setoutcb ()
{
  if (ssl) {
    if (!wcbset)
      iocb ();
  }
  else
    aios::setoutcb ();
}

/*
 * Switch an established plain connection over to SSL.
 *
 *   ctx  SSL context from which the new SSL object is created
 *   s    true to act as the SSL server (SSL_accept), false to act as
 *        the client (SSL_connect)
 *
 * Must be called at most once, on an open descriptor.  Input already
 * buffered in inb is handed to SSL through a memory BIO; output
 * already queued in outb is taken over by oldout, which dooutput
 * writes directly to the descriptor ahead of any SSL data.  Finishes by
 * kicking off the handshake via iocb.
 */
void
aiossl::startssl (SSL_CTX *ctx, bool s)
{
  assert (!ssl);
  assert (fd >= 0);

  // Remove whatever callbacks plain mode had installed.
  fdcb (fd, selread, NULL);
  fdcb (fd, selwrite, NULL);
  wcbset = false;
  rcbset = false;

  ssl = SSL_new (ctx);
  bss = BIO_new_socket (fd, BIO_NOCLOSE);

  // Feed bytes that were read before SSL started through a memory BIO;
  // doerr replaces it with the socket BIO once SSL wants more input.
  inbss = BIO_new (BIO_s_mem ());
  BIO_set_mem_eof_return (inbss, -1);
  while (inb.outiovcnt () > 0) {
    const iovec *chunk = inb.outiov ();
    const size_t len = chunk->iov_len;
    BIO_write (inbss, chunk->iov_base, len);
    inb.rembytes (len);
  }

  // Set aside output queued before SSL started.
  oldout = New refcounted<suio>;
  oldout->take (outb.tosuio ());

  // Read from the memory BIO, write to the socket BIO.
  SSL_set_bio (ssl, inbss, bss);
  const long modes = (SSL_MODE_ENABLE_PARTIAL_WRITE
		      | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
  SSL_set_mode (ssl, modes);

  server = s;
  need_handshake = true;
  sent_shutdown = false;

  // Debug-output prefixes used once the connection is in SSL mode.
  rdpref = " *=> ";
  wrpref = " <=* ";
  errpref = " *** ";

  iocb ();
}

