#include "SSL.h"

NAMESPACE_UPP

// OpenSSL-backed implementation of TcpSocket's pluggable SSL layer. Each instance
// is bound to one TcpSocket and runs an OpenSSL session over that socket's
// descriptor (see SSL_set_fd in Start()).
struct TcpSocket::SSLImp : TcpSocket::SSL {
	virtual bool Start();
	virtual bool Wait(dword flags);
	virtual int  Send(const void *buffer, int maxlen);
	virtual int  Recv(void *buffer, int maxlen);
	virtual void Close();

	TcpSocket&     socket;
	SslContext     context;
	::SSL         *ssl;   // session handle; NULL until Start() creates it, and again after Close()
	SslCertificate cert;  // peer certificate captured at the end of Start()

	void SetSSLError(const char *context);
	void SetSSLResError(const char *context, int res);
	bool IsAgain(int res) const;

	// `ssl` must start out NULL so Close() is safe even when Start() was never
	// called or failed before SSL_new() succeeded.
	SSLImp(TcpSocket& socket) : socket(socket), ssl(NULL) {}
};

// Factory installed into TcpSocket::CreateSSL: returns a new heap-allocated
// SSLImp bound to `socket`.
TcpSocket::SSL *TcpSocket::CreateSSLImp(TcpSocket& socket)
{
	return new TcpSocket::SSLImp(socket);
}

// Registers the OpenSSL implementation as TcpSocket's SSL factory.
void InitCreateSSL()
{
	TcpSocket::CreateSSL = TcpSocket::CreateSSLImp;
}

INITBLOCK {
	InitCreateSSL();
}

// Records the most recent OpenSSL error (as reported by SslGetLastError) as the
// socket error, tagged with `context`.
void TcpSocket::SSLImp::SetSSLError(const char *context)
{
	int code;
	String text = SslGetLastError(code);
	socket.SetSockError(context, code, text);
}

// Maps the return value `res` of an SSL_* call to its SSL_ERROR_* code and
// records it as the socket error, using the code's symbolic name as the text.
void TcpSocket::SSLImp::SetSSLResError(const char *context, int res)
{
	int code = SSL_get_error(ssl, res);
	String out;
	switch(code) {
#define SSLERR(c) case c: out = #c; break;
	SSLERR(SSL_ERROR_NONE)
	SSLERR(SSL_ERROR_SSL)
	SSLERR(SSL_ERROR_WANT_READ)
	SSLERR(SSL_ERROR_WANT_WRITE)
	SSLERR(SSL_ERROR_WANT_X509_LOOKUP)
	SSLERR(SSL_ERROR_SYSCALL)
	SSLERR(SSL_ERROR_ZERO_RETURN)
	SSLERR(SSL_ERROR_WANT_CONNECT)
#ifdef SSL_ERROR_WANT_ACCEPT // guard on the macro itself, matching IsAgain()
	SSLERR(SSL_ERROR_WANT_ACCEPT)
#endif
#undef SSLERR
	default: out = "unknown code"; break;
	}
	socket.SetSockError(context, code, out);
}

// True when `res` (the return of an SSL_* call) only means the operation would
// block and should be retried once the socket is ready, i.e. not a real error.
bool TcpSocket::SSLImp::IsAgain(int res) const
{
	switch(SSL_get_error(ssl, res)) {
	case SSL_ERROR_WANT_READ:
	case SSL_ERROR_WANT_WRITE:
	case SSL_ERROR_WANT_CONNECT:
#ifdef SSL_ERROR_WANT_ACCEPT
	case SSL_ERROR_WANT_ACCEPT:
#endif
		return true;
	default:
		return false;
	}
}

// Creates the SSL context and session, attaches them to the socket descriptor and
// begins the handshake in the role given by socket.mode (server for ACCEPT,
// client otherwise).
// Returns true when the handshake completed or is only waiting for I/O (IsAgain);
// returns false after recording the error on the socket.
bool TcpSocket::SSLImp::Start()
{
	bool accepting = socket.mode == ACCEPT;
	// A server-side method is required to accept; the version-flexible SSLv23_*
	// methods negotiate the highest protocol both peers support instead of pinning
	// the obsolete SSLv3 (RFC 7568).
	// NOTE(review): nothing in this file loads a certificate/private key into
	// `context`; the accepting side will need one — confirm where it is configured.
	SSL_METHOD *method = accepting ? const_cast<SSL_METHOD *>(SSLv23_server_method())
	                               : const_cast<SSL_METHOD *>(SSLv23_client_method());
	if(!context.Create(method)) {
		SetSSLError("Start: SSL context.");
		return false;
	}
	if(!(ssl = SSL_new(context))) {
		SetSSLError("Start: SSL_new");
		return false;
	}
	if(!SSL_set_fd(ssl, socket.GetSOCKET())) {
		SetSSLError("Start: SSL_set_fd");
		return false;
	}
	int res;
	if(accepting) {
		SSL_set_accept_state(ssl);
		res = SSL_accept(ssl);
		if(res <= 0 && !IsAgain(res)) {
			SetSSLResError("Start: SSL_accept", res);
			return false;
		}
	}
	else {
		SSL_set_connect_state(ssl);
		res = SSL_connect(ssl);
		if(res <= 0 && !IsAgain(res)) {
			SetSSLResError("Start: SSL_connect", res);
			return false;
		}
	}
	// If the handshake is still pending (IsAgain path) the peer certificate may
	// not be available yet.
	cert.Set(SSL_get_peer_certificate(ssl));
	return true;
}

// Waits for socket readiness. Plaintext already buffered inside OpenSSL
// (SSL_pending) counts as readable without touching the raw socket.
bool TcpSocket::SSLImp::Wait(dword flags)
{
	if((flags & WAIT_READ) && SSL_pending(ssl) > 0)
		return true;
	return socket.RawWait(flags);
}

// Encrypts and sends up to `maxlen` bytes. Returns the number of bytes written,
// or 0 when the call would block or failed (a failure also records the error).
int TcpSocket::SSLImp::Send(const void *buffer, int maxlen)
{
	int res = SSL_write(ssl, (const char *)buffer, maxlen);
	if(res > 0)
		return res;
	if(!IsAgain(res))
		SetSSLResError("SSL_write", res);
	return 0;
}

// Receives and decrypts up to `maxlen` bytes. Returns the number of bytes read,
// or 0 when the call would block, the peer closed, or an error occurred. An
// orderly shutdown from the peer sets is_eof without recording an error.
int TcpSocket::SSLImp::Recv(void *buffer, int maxlen)
{
	int res = SSL_read(ssl, (char *)buffer, maxlen);
	if(res > 0)
		return res;
	if(IsAgain(res))
		return 0;
	if(res == 0) {
		socket.is_eof = true;
		if(SSL_get_shutdown(ssl) & SSL_RECEIVED_SHUTDOWN)
			return 0;
	}
	SetSSLResError("SSL_read", res);
	return 0;
}

// Sends the SSL shutdown (if a session exists), closes the raw socket and frees
// the session. Safe to call when Start() never ran or failed early.
void TcpSocket::SSLImp::Close()
{
	if(ssl)
		SSL_shutdown(ssl);
	socket.RawClose();
	if(ssl) {
		SSL_free(ssl);
		ssl = NULL;
	}
}

#if 0
// Compiled out: older SSLSocketData implementation built on the TcpSocket::Data /
// Socket interface.

class SSLSocketData : public TcpSocket::Data
{
public:
	SSLSocketData(SslContext& context);
	virtual ~SSLSocketData();

	bool            OpenClientUnsecured(const char *host, int port, bool nodelay, dword *my_addr, int timeout, bool is_blocking);
	bool            Secure();
	bool            OpenAccept(SOCKET connection, bool nodelay, bool blocking);

	virtual int     GetKind() const { return SOCKKIND_SSL; }
	virtual bool    Peek(int timeout_msec, bool write);
	virtual int     Read(void *buf, int amount);
	virtual int     Write(const void *buf, int amount);
	virtual bool    Accept(Socket& socket, dword *ipaddr, bool nodelay, int timeout_msec);
	virtual bool    Close(int timeout_msec);
	virtual Value   GetInfo(String info) const;

	void            SetSSLError(const char *context);
	void            SetSSLResError(const char *context, int res);

public:
	SslContext&     ssl_context;
	SSL            *ssl;
	SslCertificate  cert;
};

SSLSocketData::SSLSocketData(SslContext& ssl_context)
: ssl_context(ssl_context)
{
	SSLInit().AddThread();
	ssl = NULL;
}

SSLSocketData::~SSLSocketData()
{
	Close(0);
}

void SSLSocketData::SetSSLError(const char *context)
{
	if(sock) {
		int code;
		String text = SSLGetLastError(code);
		SetSockError(context, code, text);
	}
}

void SSLSocketData::SetSSLResError(const char *context, int res)
{
	if(sock) {
		int code = SSL_get_error(ssl, res);
		String out;
		switch(code) {
#define SSLERR(c) case c: out = #c; break;
		SSLERR(SSL_ERROR_NONE)
		SSLERR(SSL_ERROR_SSL)
		SSLERR(SSL_ERROR_WANT_READ)
		SSLERR(SSL_ERROR_WANT_WRITE)
		SSLERR(SSL_ERROR_WANT_X509_LOOKUP)
		SSLERR(SSL_ERROR_SYSCALL)
		SSLERR(SSL_ERROR_ZERO_RETURN)
		SSLERR(SSL_ERROR_WANT_CONNECT)
#ifdef PLATFORM_WIN32
		SSLERR(SSL_ERROR_WANT_ACCEPT)
#endif
		default: out = "unknown code"; break;
		}
		SetSockError(context, code, out);
	}
}

bool SSLSocketData::Peek(int timeout_msec, bool write)
{
	if(ssl && !write && SSL_pending(ssl) > 0)
		return true;
	return Data::Peek(timeout_msec, write);
}

int SSLSocketData::Read(void *buf, int amount)
{
	if(!ssl)
		return Data::Read(buf, amount);
	int res = SSL_read(ssl, (char *)buf, amount);
	if(res == 0) {
		is_eof = true;
		if(SSL_get_shutdown(ssl) & SSL_RECEIVED_SHUTDOWN)
			return 0;
	}
	if(res <= 0)
		SetSSLResError("SSL_read", res);
#ifndef NOFAKEERROR
	if(fake_error && res > 0) {
		if((fake_error -= res) <= 0) {
			fake_error = 0;
			SetSockError("SSL_read", 0, "fake error");
			return -1;
		}
		else
			RLOG("SSLSocketData::Read: fake error after " << fake_error);
	}
#endif
	return res;
}

int SSLSocketData::Write(const void *buf, int amount)
{
	if(!ssl)
		return Data::Write(buf, amount);
	int res = SSL_write(ssl, (const char *)buf, amount);
	if(res <= 0)
		SetSSLResError("SSL_write", res);
	return res;
}

bool SSLSocketData::OpenClientUnsecured(const char *host, int port, bool nodelay, dword *my_addr, int timeout, bool blocking)
{
	return Data::OpenClient(host, port, nodelay, my_addr, timeout, /*blocking*/true);
}

bool SSLSocketData::Secure()
{
	if(!(ssl = SSL_new(ssl_context))) {
		SetSSLError("OpenClient / SSL_new");
		return false;
	}
	if(!SSL_set_fd(ssl, socket)) {
		SetSSLError("OpenClient / SSL_set_fd");
		return false;
	}
	SSL_set_connect_state(ssl);
	int res = SSL_connect(ssl);
	if(res <= 0) {
		SetSSLResError("OpenClient / SSL_connect", res);
		return false;
	}
	cert.Set(SSL_get_peer_certificate(ssl));
	return true;
}

bool SSLSocketData::OpenAccept(SOCKET conn, bool nodelay, bool blocking)
{
	Attach(conn, nodelay, blocking);
	if(!(ssl = SSL_new(ssl_context))) {
		SetSSLError("Accept / SSL_new");
		return false;
	}
	if(!SSL_set_fd(ssl, socket)) {
		SetSSLError("Accept / SSL_set_fd");
		return false;
	}
	SSL_set_accept_state(ssl);
	int res = SSL_accept(ssl);
	if(res <= 0) {
		SetSSLResError("Accept / SSL_accept", res);
		return false;
	}
	cert.Set(SSL_get_peer_certificate(ssl));
	return true;
}

bool SSLSocketData::Accept(Socket& socket, dword *ipaddr, bool nodelay, int timeout_msec)
{
	SOCKET connection = AcceptRaw(ipaddr, timeout_msec);
	if(connection == INVALID_SOCKET)
		return false;
	One<SSLSocketData> data = new SSLSocketData(ssl_context);
	if(!data->OpenAccept(connection, nodelay, is_blocking))
		return false;
	socket.Attach(-data);
	return true;
}

bool SSLSocketData::Close(int timeout_msec)
{
	if(ssl)
		SSL_shutdown(ssl);
	bool res = Data::Close(timeout_msec);
	if(ssl) {
		SSL_free(ssl);
		ssl = NULL;
	}
	return res;
}

Value SSLSocketData::GetInfo(String info) const
{
	if(info == SSLInfoCipher())
		return SSL_get_cipher(ssl);
	if(info == SSLInfoCertAvail())
		return cert.IsEmpty() ? 0 : 1;
	if(info == SSLInfoCertVerified())
		return SSL_get_verify_result(ssl) == X509_V_OK ? 1 : 0;
	if(info == SSLInfoCertSubjectName())
		return cert.IsEmpty() ? String::GetVoid() : cert.GetSubjectName();
	if(info == SSLInfoCertIssuerName())
		return cert.IsEmpty() ? String::GetVoid() : cert.GetIssuerName();
	if(info == SSLInfoCertNotBefore())
		return cert.IsEmpty() ? Date(Null) : cert.GetNotBefore();
	if(info == SSLInfoCertNotAfter())
		return cert.IsEmpty() ? Date(Null) : cert.GetNotAfter();
	if(info == SSLInfoCertVersion())
		return cert.IsEmpty() ? int(Null) : cert.GetVersion();
	if(info == SSLInfoCertSerialNumber())
		return cert.IsEmpty() ? String::GetVoid() : cert.GetSerialNumber();
	return Data::GetInfo(info);
}

bool SSLServerSocket(Socket& socket, SslContext& ssl_context, int port, bool nodelay, int listen_count, bool blocking)
{
	One<SSLSocketData> data = new SSLSocketData(ssl_context);
	if(!data->OpenServer(port, nodelay, listen_count, blocking))
		return false;
	socket.Attach(-data);
	return true;
}

bool SSLClientSocket(Socket& socket, SslContext& ssl_context, const char *host, int port, bool nodelay, dword *my_addr, int timeout, bool blocking)
{
	One<SSLSocketData> data = new SSLSocketData(ssl_context);
	if(!data->OpenClient(host, port, nodelay, my_addr, timeout, blocking))
		return false;
	if(!data->Secure())
		return false;
	socket.Attach(-data);
	return true;
}

bool SSLClientSocketUnsecured(Socket& socket, SslContext& ssl_context, const char *host, int port, bool nodelay, dword *my_addr, int timeout, bool is_blocking)
{
	One<SSLSocketData> data = new SSLSocketData(ssl_context);
	if(data->OpenClientUnsecured(host, port, nodelay, my_addr, timeout, is_blocking)) {
		socket.Attach(-data);
		return true;
	}
	return false;
}

bool SSLSecureSocket(Socket& socket)
{
	SSLSocketData *sd = dynamic_cast<SSLSocketData *>(~socket.data);
	if(!sd)
		return false;
	return sd->Secure();
}
#endif

END_UPP_NAMESPACE