/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#ifndef _THRIFT_TRANSPORT_TSSLSOCKET_H_
#define _THRIFT_TRANSPORT_TSSLSOCKET_H_ 1

// Put this first to avoid WIN32 build failure
#include <thrift/transport/TSocket.h>

#include <openssl/ssl.h>
#include <string>
#include <thrift/concurrency/Mutex.h>

namespace apache {
namespace thrift {
namespace transport {

class AccessManager;
class SSLContext;

// Protocol levels accepted by the SSLContext and TSSLSocketFactory constructors.
// Every enumerator carries an explicit numeric value; value 1 is left unused
// because the SSLv2 entry is commented out.
enum SSLProtocol {
  SSLTLS  = 0,  // Supports SSLv2 and SSLv3 handshake but only negotiates at TLSv1_0 or later.
//SSLv2   = 1,  // HORRIBLY INSECURE! Deliberately disabled; do not re-enable.
  SSLv3   = 2,  // Supports SSLv3 only - also horribly insecure!
  TLSv1_0 = 3,  // Supports TLSv1_0 or later.
  TLSv1_1 = 4,  // Supports TLSv1_1 or later.
  TLSv1_2 = 5,  // Supports TLSv1_2 or later.
  LATEST  = TLSv1_2  // Alias for the newest protocol level listed in this enum.
};

// Return codes of TSSLSocket::waitForEvent():
//   TSSL_EINTR - EINTR happened on the underlying socket
//   TSSL_DATA  - data is available on the socket
#define TSSL_EINTR 0
#define TSSL_DATA 1

/**
 * Initialize the OpenSSL library.
 *
 * This function, or some other equivalent OpenSSL initialization, must
 * have been called before TSSLSocket is used.  If you set
 * TSSLSocketFactory to use manual OpenSSL initialization, call this
 * function yourself or otherwise ensure OpenSSL is initialized.
 */
void initializeOpenSSL();
/**
 * Clean up the OpenSSL library.
 *
 * Call this once all use of OpenSSL functionality is finished.  If you
 * set TSSLSocketFactory to use manual OpenSSL initialization, call this
 * function yourself, or ensure that whatever initialized OpenSSL also
 * cleans it up.
 */
void cleanupOpenSSL();

/**
 * OpenSSL implementation for SSL socket interface.
 */
class TSSLSocket : public TSocket {
public:
  ~TSSLSocket() override;
  /**
   * TTransport interface.
   *
   * These override the corresponding TSocket members; their definitions are
   * not part of this header.
   */
  bool isOpen() const override;
  bool peek() override;
  void open() override;
  void close() override;
  bool hasPendingDataToRead() override;
  uint32_t read(uint8_t* buf, uint32_t len) override;
  void write(const uint8_t* buf, uint32_t len) override;
  uint32_t write_partial(const uint8_t* buf, uint32_t len) override;
  void flush() override;
  /**
   * Set whether to use client or server side SSL handshake protocol.
   *
   * @param flag  Use server side handshake protocol if true.
   */
  void server(bool flag) { server_ = flag; }
  /**
   * Determine whether the SSL socket is in server or client mode.
   *
   * @return true for server mode, false for client mode
   */
  bool server() const { return server_; }
  /**
   * Set AccessManager.
   *
   * @param manager  Instance of AccessManager; the shared pointer is stored
   *                 in this socket.
   */
  virtual void access(std::shared_ptr<AccessManager> manager) { access_ = manager; }
  /**
   * Set the eventSafe flag if libevent is used.  The flag can only be
   * switched on through this interface; there is no call to clear it.
   */
  void setLibeventSafe() { eventSafe_ = true; }
  /**
   * Determine whether this SSL socket has been marked libevent safe.
   */
  bool isLibeventSafe() const { return eventSafe_; }

protected:
  // All constructors are protected; TSSLSocketFactory is a friend (declared
  // below) and is the intended way to create instances.

  /**
   * Constructor.
   *
   * @param ctx    Shared SSLContext to use for this socket
   * @param config Optional TConfiguration (defaults to nullptr)
   */
  TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<TConfiguration> config = nullptr);
  /**
   * Constructor with an interrupt signal.
   *
   * @param ctx               Shared SSLContext to use for this socket
   * @param interruptListener Socket used as the interrupt signal
   * @param config            Optional TConfiguration (defaults to nullptr)
   */
  TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<THRIFT_SOCKET> interruptListener,
            std::shared_ptr<TConfiguration> config = nullptr);
  /**
   * Constructor, create an instance of TSSLSocket given an existing socket.
   *
   * @param ctx    Shared SSLContext to use for this socket
   * @param socket An existing socket
   * @param config Optional TConfiguration (defaults to nullptr)
   */
  TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<TConfiguration> config = nullptr);
  /**
   * Constructor, create an instance of TSSLSocket given an existing socket that can be interrupted.
   *
   * @param ctx               Shared SSLContext to use for this socket
   * @param socket            An existing socket
   * @param interruptListener Socket used as the interrupt signal
   * @param config            Optional TConfiguration (defaults to nullptr)
   */
  TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener,
            std::shared_ptr<TConfiguration> config = nullptr);
  /**
   * Constructor.
   *
   * @param ctx    Shared SSLContext to use for this socket
   * @param host   Remote host name
   * @param port   Remote port number
   * @param config Optional TConfiguration (defaults to nullptr)
   */
  TSSLSocket(std::shared_ptr<SSLContext> ctx, std::string host, int port, std::shared_ptr<TConfiguration> config = nullptr);
    /**
     * Constructor with an interrupt signal.
     *
     * @param ctx               Shared SSLContext to use for this socket
     * @param host              Remote host name
     * @param port              Remote port number
     * @param interruptListener Socket used as the interrupt signal
     * @param config            Optional TConfiguration (defaults to nullptr)
     */
    TSSLSocket(std::shared_ptr<SSLContext> ctx, std::string host, int port, std::shared_ptr<THRIFT_SOCKET> interruptListener,
              std::shared_ptr<TConfiguration> config = nullptr);
  /**
   * Authorize peer access after SSL handshake completes.
   *
   * Virtual, so subclasses may replace the authorization policy.
   */
  virtual void authorize();
  /**
   * Initiate SSL handshake if not already initiated.
   */
  void initializeHandshake();
  /**
   * Initialize SSL handshake parameters.
   */
  void initializeHandshakeParams();
  /**
   * Check whether the SSL handshake is completed or not.
   */
  bool checkHandshake();
  /**
   * Waits for a socket or shutdown event.
   *
   * @param wantRead NOTE(review): presumably selects waiting for readability
   *                 (true) versus writability (false) - confirm against the
   *                 implementation.
   *
   * @throw TTransportException::INTERRUPTED if an interrupt is signaled.
   *
   * @return TSSL_EINTR if EINTR happened on the underlying socket
   *         TSSL_DATA  if data is available on the socket.
   */
  unsigned int waitForEvent(bool wantRead);

  bool server_;                            // true: server side handshake; false: client side
  SSL* ssl_;                               // raw OpenSSL handle for this connection
  std::shared_ptr<SSLContext> ctx_;        // shared SSL context wrapper
  std::shared_ptr<AccessManager> access_;  // access manager installed via access()
  friend class TSSLSocketFactory;          // lets the factory reach the protected constructors

private:
  bool handshakeCompleted_;
  int readRetryCount_;
  bool eventSafe_;  // set by setLibeventSafe(), read by isLibeventSafe()

  // Private initialization helper; NOTE(review): presumably shared by all
  // constructors - confirm in the implementation file.
  void init();
};

/**
 * SSL socket factory. SSL sockets should be created via SSL factory.
 * The factory will automatically initialize and cleanup openssl as long as
 * there is a TSSLSocketFactory instantiated, and as long as the static
 * boolean manualOpenSSLInitialization_ is set to false, the default.
 *
 * If you would like to initialize and cleanup openssl yourself, set
 * manualOpenSSLInitialization_ to true and TSSLSocketFactory will no
 * longer be responsible for openssl initialization and teardown.
 *
 * It is the responsibility of the code using TSSLSocketFactory to
 * ensure that the factory lifetime exceeds the lifetime of any sockets
 * it might create.  If this is not guaranteed, a socket may call into
 * openssl after the socket factory has cleaned up openssl!  This
 * guarantee is unnecessary if manualOpenSSLInitialization_ is true,
 * however, since it would be up to the consuming application instead.
 */
class TSSLSocketFactory {
public:
  /**
   * Constructor/Destructor
   *
   * @param protocol The SSL/TLS protocol to use (defaults to SSLTLS).
   */
  TSSLSocketFactory(SSLProtocol protocol = SSLTLS);
  virtual ~TSSLSocketFactory();
  /**
   * Create an instance of TSSLSocket with a fresh new socket.
   */
  virtual std::shared_ptr<TSSLSocket> createSocket();
  /**
   * Create an instance of TSSLSocket with a fresh new socket, which is interruptable.
   *
   * @param interruptListener Socket used as the interrupt signal
   */
  virtual std::shared_ptr<TSSLSocket> createSocket(std::shared_ptr<THRIFT_SOCKET> interruptListener);
  /**
   * Create an instance of TSSLSocket with the given socket.
   *
   * @param socket An existing socket.
   */
  virtual std::shared_ptr<TSSLSocket> createSocket(THRIFT_SOCKET socket);
  /**
   * Create an instance of TSSLSocket with the given socket which is interruptable.
   *
   * @param socket            An existing socket.
   * @param interruptListener Socket used as the interrupt signal
   */
  virtual std::shared_ptr<TSSLSocket> createSocket(THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener);
  /**
   * Create an instance of TSSLSocket.
   *
   * @param host  Remote host to be connected to
   * @param port  Remote port to be connected to
   */
  virtual std::shared_ptr<TSSLSocket> createSocket(const std::string& host, int port);
  /**
   * Create an instance of TSSLSocket which is interruptable.
   *
   * @param host              Remote host to be connected to
   * @param port              Remote port to be connected to
   * @param interruptListener Socket used as the interrupt signal
   */
  virtual std::shared_ptr<TSSLSocket> createSocket(const std::string& host, int port, std::shared_ptr<THRIFT_SOCKET> interruptListener);
  /**
   * Set ciphers to be used in SSL handshake process.
   *
   * @param enable  A list of ciphers
   */
  virtual void ciphers(const std::string& enable);
  /**
   * Enable/Disable authentication.
   *
   * @param required Require peer to present valid certificate if true
   */
  virtual void authenticate(bool required);
  /**
   * Load server certificate.
   *
   * @param path   Path to the certificate file
   * @param format Certificate file format (defaults to "PEM")
   */
  virtual void loadCertificate(const char* path, const char* format = "PEM");
  /**
   * Buffer-based counterpart of loadCertificate().
   *
   * @param aCertificate NOTE(review): presumably the certificate contents
   *                     held in memory rather than a file path - confirm.
   * @param format       Certificate format (defaults to "PEM")
   */
  virtual void loadCertificateFromBuffer(const char* aCertificate, const char* format = "PEM");
  /**
   * Load private key.
   *
   * @param path   Path to the private key file
   * @param format Private key file format (defaults to "PEM")
   */
  virtual void loadPrivateKey(const char* path, const char* format = "PEM");
  /**
   * Buffer-based counterpart of loadPrivateKey().
   *
   * @param aPrivateKey NOTE(review): presumably the private key contents
   *                    held in memory rather than a file path - confirm.
   * @param format      Private key format (defaults to "PEM")
   */
  virtual void loadPrivateKeyFromBuffer(const char* aPrivateKey, const char* format = "PEM");
  /**
   * Load trusted certificates from specified file.
   *
   * @param path   Path to trusted certificate file
   * @param capath Optional, defaults to nullptr.  NOTE(review): presumably a
   *               directory of trusted certificates - confirm against the
   *               implementation.
   */
  virtual void loadTrustedCertificates(const char* path, const char* capath = nullptr);
  /**
   * Buffer-based counterpart of loadTrustedCertificates().
   *
   * @param aCertificate NOTE(review): presumably trusted certificate contents
   *                     held in memory - confirm.
   * @param aChain       Optional, defaults to nullptr.  NOTE(review):
   *                     presumably an additional certificate chain - confirm.
   */
  virtual void loadTrustedCertificatesFromBuffer(const char* aCertificate, const char* aChain = nullptr);
  /**
   * Default randomize method.
   */
  virtual void randomize();
  /**
   * Override default OpenSSL password callback with getPassword().
   */
  void overrideDefaultPasswordCallback();
  /**
   * Set/Unset server mode.
   *
   * @param flag  Server mode if true
   */
  virtual void server(bool flag) { server_ = flag; }
  /**
   * Determine whether the socket is in server or client mode.
   *
   * @return true, if server mode, or, false, if client mode
   */
  virtual bool server() const { return server_; }
  /**
   * Set AccessManager.
   *
   * @param manager  The AccessManager instance
   */
  virtual void access(std::shared_ptr<AccessManager> manager) { access_ = manager; }
  /**
   * Set the class-wide manualOpenSSLInitialization_ flag.
   *
   * The flag is static, so the setting applies to every factory instance.
   *
   * @param manualOpenSSLInitialization  New value for the flag
   */
  static void setManualOpenSSLInitialization(bool manualOpenSSLInitialization) {
    manualOpenSSLInitialization_ = manualOpenSSLInitialization;
  }

protected:
  // Shared SSL context wrapper, accessible to subclasses.
  std::shared_ptr<SSLContext> ctx_;

  /**
   * Override this method for custom password callback. It may be called
   * multiple times at any time during a session as necessary.
   *
   * The default implementation does nothing.
   *
   * @param password Pass collected password to OpenSSL
   * @param size     Maximum length of password including NULL character
   */
  virtual void getPassword(std::string& /* password */, int /* size */) {}

private:
  bool server_;                            // server mode if true, client mode otherwise
  std::shared_ptr<AccessManager> access_;  // access manager installed via access()
  // NOTE(review): mutex_ and count_ presumably guard and count live factory
  // instances to drive the automatic OpenSSL initialization/cleanup - confirm
  // in the implementation file.
  static concurrency::Mutex mutex_;
  static uint64_t count_;
  // Written by setManualOpenSSLInitialization().
  THRIFT_EXPORT static bool manualOpenSSLInitialization_;
  // Per-socket setup helper; NOTE(review): presumably applied to each socket
  // produced by createSocket() - confirm.
  void setup(std::shared_ptr<TSSLSocket> ssl);
  // Static password callback; NOTE(review): presumably the function installed
  // by overrideDefaultPasswordCallback() and forwarding to getPassword() -
  // confirm.
  static int passwordCallback(char* password, int size, int, void* data);
};

/**
 * SSL exception.
 */
class TSSLException : public TTransportException {
public:
  /**
   * Build an SSL exception carrying the given message.  The base class is
   * always constructed with TTransportException::INTERNAL_ERROR.
   *
   * @param message Human readable description of the failure
   */
  TSSLException(const std::string& message)
    : TTransportException(TTransportException::INTERNAL_ERROR, message) {}

  /**
   * @return the stored message, or the literal "TSSLException" when the
   *         stored message is empty
   */
  const char* what() const noexcept override {
    return message_.empty() ? "TSSLException" : message_.c_str();
  }
};

/**
 * Wrap OpenSSL SSL_CTX into a class.
 */
class SSLContext {
public:
  /**
   * @param protocol Protocol level for this context (defaults to SSLTLS)
   */
  SSLContext(const SSLProtocol& protocol = SSLTLS);
  virtual ~SSLContext();
  /**
   * @return an OpenSSL SSL handle.  NOTE(review): presumably a new SSL
   *         object created from the wrapped SSL_CTX, with ownership passing
   *         to the caller - confirm against the implementation.
   */
  SSL* createSSL();
  /**
   * @return the raw SSL_CTX pointer held by this wrapper
   */
  SSL_CTX* get() { return ctx_; }

private:
  // NOTE(review): raw pointer member plus a user-declared destructor, while
  // copy construction/assignment are left implicit.  If the destructor
  // releases ctx_, copying an SSLContext would release it twice - confirm,
  // and consider deleting the copy operations.
  SSL_CTX* ctx_;
};

/**
 * Callback interface for access control. It's meant to verify the remote host.
 * It's constructed when application starts and set to TSSLSocketFactory
 * instance. It's passed onto all TSSLSocket instances created by this factory
 * object.
 */
class AccessManager {
public:
  /**
   * Outcome of a single verification step.
   */
  enum Decision {
    DENY = -1, // deny access
    SKIP = 0,  // cannot make decision, move on to next (if any)
    ALLOW = 1  // allow access
  };

  /**
   * Destructor
   */
  virtual ~AccessManager() = default;

  /**
   * Decide whether the peer should be granted access based on its address.
   * It's called once after the SSL handshake completes successfully, before
   * the peer certificate is examined.
   *
   * If a valid decision (ALLOW or DENY) is returned, the peer certificate is
   * not to be verified.
   *
   * @param  sa Peer IP address
   * @return ALLOW or DENY to settle the matter, SKIP if no decision can be
   *         made; this default implementation always returns DENY
   */
  virtual Decision verify(const sockaddr_storage& /* sa */) noexcept {
    return DENY;
  }

  /**
   * Decide whether the peer should be granted access based on a name taken
   * from its certificate. It's called every time a DNS subjectAltName/common
   * name is extracted from the peer's certificate.
   *
   * @param  host Client mode: host name returned by TSocket::getHost()
   *              Server mode: host name returned by TSocket::getPeerHost()
   * @param  name SubjectAltName or common name extracted from peer certificate
   * @param  size Length of name
   * @return ALLOW or DENY to settle the matter, SKIP if no decision can be
   *         made; this default implementation always returns DENY
   *
   * Note: The "name" parameter may be UTF8 encoded.
   */
  virtual Decision verify(const std::string& /* host */,
                          const char* /* name */,
                          int /* size */) noexcept {
    return DENY;
  }

  /**
   * Decide whether the peer should be granted access based on an IP address
   * taken from its certificate. It's called every time an IP subjectAltName
   * is extracted from the peer's certificate.
   *
   * @param  sa   Peer IP address retrieved from the underlying socket
   * @param  data IP address extracted from certificate
   * @param  size Length of the IP address
   * @return ALLOW or DENY to settle the matter, SKIP if no decision can be
   *         made; this default implementation always returns DENY
   */
  virtual Decision verify(const sockaddr_storage& /* sa */,
                          const char* /* data */,
                          int /* size */) noexcept {
    return DENY;
  }
};

// Namespace-level shorthand so AccessManager::Decision can be written as Decision.
using Decision = AccessManager::Decision;

/**
 * AccessManager subclass that overrides all three verify() overloads.
 * The overrides are only declared here; their definitions (and therefore
 * the actual policy) live outside this header.
 * NOTE(review): the name suggests this is the default policy for client-side
 * sockets - confirm where it is installed.
 */
class DefaultClientAccessManager : public AccessManager {
public:
  // AccessManager interface
  // Decision based on the peer IP address alone.
  Decision verify(const sockaddr_storage& sa) noexcept override;
  // Decision based on a DNS subjectAltName/common name from the peer certificate.
  Decision verify(const std::string& host, const char* name, int size) noexcept override;
  // Decision based on an IP subjectAltName from the peer certificate.
  Decision verify(const sockaddr_storage& sa, const char* data, int size) noexcept override;
};
}
}
}

#endif
