Core: SSL support for Socket, finishing touches...

git-svn-id: svn://ultimatepp.org/upp/trunk@4785 f0d560ea-af0d-0410-9eb7-867de7ffcac7
This commit is contained in:
cxl 2012-04-15 11:20:50 +00:00
parent 8aace300c4
commit eef7c20270
12 changed files with 476 additions and 394 deletions

View file

@ -4,8 +4,6 @@ NAMESPACE_UPP
#define LTIMING(x) // TIMING(x)
int msecs(int from) { return (int)GetTickCount() - from; }
#ifdef PLATFORM_WIN32
#include <mmsystem.h>
#endif

View file

@ -73,8 +73,6 @@ void CloseStdLog();
void HexDump(Stream& s, const void *ptr, int size, int maxsize = INT_MAX);
int msecs(int from = 0);
String GetTypeName(const char *type_name);
inline String GetTypeName(const ::std::type_info& tinfo) { return GetTypeName(tinfo.name()); }

View file

@ -1,3 +1,4 @@
String WwwFormat(Time tm);
String FormatIP(dword _ip);
String UrlEncode(const String& s);
@ -81,7 +82,7 @@ class TcpSocket {
struct SSL {
virtual bool Start() = 0;
virtual bool Wait(dword flags) = 0;
virtual bool Wait(dword flags, int end_time) = 0;
virtual int Send(const void *buffer, int maxlen) = 0;
virtual int Recv(void *buffer, int maxlen) = 0;
virtual void Close() = 0;
@ -92,6 +93,8 @@ class TcpSocket {
One<SSL> ssl;
One<SSLInfo> sslinfo;
String cert, pkey;
bool asn1;
struct SSLImp;
friend struct SSLImp;
@ -101,7 +104,9 @@ class TcpSocket {
friend void InitCreateSSL();
bool RawWait(dword flags);
int GetEndTime() const;
bool RawWait(dword flags, int end_time);
bool Wait(dword events, int end_time);
SOCKET AcceptRaw(dword *ipaddr, int timeout_msec);
bool Open(int family, int type, int protocol);
int RawRecv(void *buffer, int maxlen);
@ -111,9 +116,11 @@ class TcpSocket {
bool RawConnect(addrinfo *info);
void RawClose();
void ReadBuffer();
void ReadBuffer(int end_time);
int Get_();
int Peek_();
int Peek_(int end_time);
int Peek(int end_time) { return ptr < end ? *ptr : Peek_(end_time); }
void Reset();
@ -170,18 +177,21 @@ public:
int Get() { return ptr < end ? *ptr++ : Get_(); }
int Get(void *buffer, int len);
String Get(int len);
int GetAll(void *buffer, int len) { return Get(buffer, len) == len; }
String GetAll(int len);
String GetLine(int maxlen = 2000000);
int Put(const char *s, int len);
int Put(const String& s) { return Put(s.Begin(), s.GetLength()); }
bool PutAll(const char *s, int len) { return Put(s, len) == len; }
bool PutAll(const String& s) { return Put(s) == s.GetCount(); }
bool GetAll(void *buffer, int len);
String GetAll(int len);
String GetLine(int maxlen = 65536);
bool PutAll(const char *s, int len);
bool PutAll(const String& s);
bool StartSSL();
bool IsSSL() const { return ssl; }
bool SSLHandshake();
void SSLCertificate(const String& cert, const String& pkey, bool asn1);
const SSLInfo *GetSSLInfo() const { return ~sslinfo; }
TcpSocket& Timeout(int ms) { timeout = ms; return *this; }

View file

@ -2,6 +2,19 @@
NAMESPACE_UPP
String WwwFormat(Time tm)
{
static const char *dayofweek[] =
{ "Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat" };
static const char *month[] =
{ "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec" };
return String().Cat()
<< dayofweek[DayOfWeek(tm)] << ", "
<< (int)tm.day << ' ' << month[tm.month - 1]
<< ' ' << (int)tm.year
<< ' ' << Sprintf("%2d:%02d:%02d +0100", tm.hour, tm.minute, tm.second);
}
String FormatIP(dword _ip)
{
byte ip[4];

View file

@ -130,7 +130,7 @@ private:
String SslGetLastError(int& code);
String SslGetLastError();
String SslToString(X509_NAME *name);
Date ASN1ToDate(ASN1_STRING *time);
String ASN1ToString(ASN1_STRING *s);
Date Asn1ToDate(ASN1_STRING *time);
String Asn1ToString(ASN1_STRING *s);
END_UPP_NAMESPACE

View file

@ -6,7 +6,7 @@ NAMESPACE_UPP
struct TcpSocket::SSLImp : TcpSocket::SSL {
virtual bool Start();
virtual bool Wait(dword flags);
virtual bool Wait(dword flags, int end_time);
virtual int Send(const void *buffer, int maxlen);
virtual int Recv(void *buffer, int maxlen);
virtual void Close();
@ -109,7 +109,8 @@ bool TcpSocket::SSLImp::Start()
SetSSLError("Start: SSL context.");
return false;
}
// context.VerifyPeer();
if(socket.cert.GetCount())
context.UseCertificate(socket.cert, socket.pkey, socket.asn1);
if(!(ssl = SSL_new(context))) {
SetSSLError("Start: SSL_new");
return false;
@ -157,12 +158,12 @@ dword TcpSocket::SSLImp::Handshake()
return 0;
}
bool TcpSocket::SSLImp::Wait(dword flags)
bool TcpSocket::SSLImp::Wait(dword flags, int end_time)
{
LLOG("SSL Wait");
if((flags & WAIT_READ) && SSL_pending(ssl) > 0)
return true;
return socket.RawWait(flags);
return socket.RawWait(flags, end_time);
}
int TcpSocket::SSLImp::Send(const void *buffer, int maxlen)

View file

@ -111,13 +111,13 @@ String SslCertificate::GetIssuerName() const
Date SslCertificate::GetNotBefore() const
{
ASSERT(!IsEmpty());
return ASN1ToDate(X509_get_notBefore(cert));
return Asn1ToDate(X509_get_notBefore(cert));
}
Date SslCertificate::GetNotAfter() const
{
ASSERT(!IsEmpty());
return ASN1ToDate(X509_get_notAfter(cert));
return Asn1ToDate(X509_get_notAfter(cert));
}
int SslCertificate::GetVersion() const
@ -129,7 +129,7 @@ int SslCertificate::GetVersion() const
String SslCertificate::GetSerialNumber() const
{
ASSERT(!IsEmpty());
return ASN1ToString(X509_get_serialNumber(cert));
return Asn1ToString(X509_get_serialNumber(cert));
}
SslContext::SslContext(SSL_CTX *c)
@ -186,7 +186,7 @@ String SslToString(X509_NAME *name)
return X509_NAME_oneline(name, buffer, sizeof(buffer));
}
Date ASN1ToDate(ASN1_STRING *time)
Date Asn1ToDate(ASN1_STRING *time)
{
if(!time) return Null;
int digit = 0;
@ -200,7 +200,7 @@ Date ASN1ToDate(ASN1_STRING *time)
return Date(year2 + (year2 < 90 ? 2000 : 1900), month, day);
}
String ASN1ToString(ASN1_STRING *s)
String Asn1ToString(ASN1_STRING *s)
{
return String(s->data, s->length);
}

View file

@ -272,6 +272,7 @@ TcpSocket::TcpSocket()
Reset();
timeout = Null;
waitstep = 20;
asn1 = false;
}
bool TcpSocket::Open(int family, int type, int protocol)
@ -343,9 +344,14 @@ bool TcpSocket::Accept(TcpSocket& ls)
Close();
Init();
Reset();
if(timeout && !ls.WaitRead())
return false;
ASSERT(ls.IsOpen());
if(timeout) {
int h = ls.GetTimeout();
bool b = ls.Timeout(timeout).Wait(WAIT_READ, GetEndTime());
ls.Timeout(h);
if(!b)
return false;
}
if(!Open(ls.ipv6 ? AF_INET6 : AF_INET, SOCK_STREAM, 0))
return false;
socket = accept(ls.GetSOCKET(), NULL, NULL);
@ -519,12 +525,11 @@ String TcpSocket::GetHostName()
return buffer;
}
bool TcpSocket::RawWait(dword flags)
bool TcpSocket::RawWait(dword flags, int end_time)
{
LLOG("Wait(" << timeout << ", " << flags << ")");
LLOG("Wait(" << msecs() << " - " << end_time << ", " << flags << ")");
if((flags & WAIT_READ) && ptr != end)
return true;
int end_time = msecs() + timeout;
if(socket == INVALID_SOCKET)
return false;
for(;;) {
@ -564,9 +569,19 @@ bool TcpSocket::RawWait(dword flags)
}
}
bool TcpSocket::Wait(dword flags, int end_time)
{
return ssl ? ssl->Wait(flags, end_time) : RawWait(flags, end_time);
}
int TcpSocket::GetEndTime() const
{
return IsNull(timeout) ? INT_MAX : msecs() + timeout;
}
bool TcpSocket::Wait(dword flags)
{
return ssl ? ssl->Wait(flags) : RawWait(flags);
return Wait(flags, GetEndTime());
}
int TcpSocket::Put(const char *s, int length)
@ -579,8 +594,9 @@ int TcpSocket::Put(const char *s, int length)
return 0;
done = 0;
bool peek = false;
int end_time = GetEndTime();
while(done < length) {
if(peek && !WaitWrite())
if(peek && !Wait(WAIT_WRITE, end_time))
return done;
peek = false;
int count = Send(s + done, length - done);
@ -595,6 +611,26 @@ int TcpSocket::Put(const char *s, int length)
return done;
}
bool TcpSocket::PutAll(const char *s, int len)
{
if(Put(s, len) != len) {
if(!IsError())
SetSockError("GePutAll", -1, "timeout");
return false;
}
return true;
}
bool TcpSocket::PutAll(const String& s)
{
if(Put(s) != s.GetCount()) {
if(!IsError())
SetSockError("GePutAll", -1, "timeout");
return false;
}
return true;
}
int TcpSocket::RawRecv(void *buf, int amount)
{
int res = recv(socket, (char *)buf, amount, 0);
@ -619,10 +655,10 @@ int TcpSocket::Recv(void *buffer, int maxlen)
return ssl ? ssl->Recv(buffer, maxlen) : RawRecv(buffer, maxlen);
}
void TcpSocket::ReadBuffer()
void TcpSocket::ReadBuffer(int end_time)
{
ptr = end = buffer;
if(WaitRead())
if(Wait(WAIT_READ, end_time))
end = buffer + Recv(buffer, BUFFERSIZE);
}
@ -630,16 +666,21 @@ int TcpSocket::Get_()
{
if(!IsOpen() || IsError() || IsEof() || IsAbort())
return -1;
ReadBuffer();
ReadBuffer(GetEndTime());
return ptr < end ? *ptr++ : -1;
}
int TcpSocket::Peek_(int end_time)
{
if(!IsOpen() || IsError() || IsEof() || IsAbort())
return -1;
ReadBuffer(end_time);
return ptr < end ? *ptr : -1;
}
int TcpSocket::Peek_()
{
if(!IsOpen() || IsError() || IsEof() || IsAbort())
return -1;
ReadBuffer();
return ptr < end ? *ptr : -1;
return Peek_(GetEndTime());
}
int TcpSocket::Get(void *buffer, int count)
@ -663,8 +704,9 @@ int TcpSocket::Get(void *buffer, int count)
ptr += count;
return count;
}
int end_time = GetEndTime();
while(done < count && !IsError() && !IsEof()) {
if(!WaitRead())
if(!Wait(WAIT_READ, end_time))
break;
int part = Recv((char *)buffer + done, count - done);
if(part > 0)
@ -687,19 +729,41 @@ String TcpSocket::Get(int count)
return out;
}
bool TcpSocket::GetAll(void *buffer, int len)
{
if(Get(buffer, len) == len)
return true;
if(!IsError())
SetSockError("GetAll", -1, "timeout");
return false;
}
String TcpSocket::GetAll(int len)
{
String s = Get(len);
return s.GetCount() == len ? s : String::GetVoid();
if(s.GetCount() != len) {
if(IsEof())
return s;
if(!IsError())
SetSockError("GetAll", -1, "timeout");
return String::GetVoid();
}
return s;
}
String TcpSocket::GetLine(int maxlen)
{
String ln;
int end_time = GetEndTime();
for(;;) {
int c = Peek();
if(c < 0)
int c = Peek(end_time);
if(c < 0) {
if(IsEof())
return ln;
if(!IsError())
SetSockError("GetLine", -1, "timeout");
return String::GetVoid();
}
Get();
if(c == '\n')
return ln;
@ -767,6 +831,13 @@ bool TcpSocket::SSLHandshake()
return false;
}
void TcpSocket::SSLCertificate(const String& cert_, const String& pkey_, bool asn1_)
{
cert = cert_;
pkey = pkey_;
asn1 = asn1_;
}
int SocketWaitEvent::Wait(int timeout)
{
FD_ZERO(read);

View file

@ -125,6 +125,9 @@ dword GetTickCount() {
gettimeofday(tv, tz);
return (dword)tv->tv_sec * 1000 + tv->tv_usec / 1000;
}
int msecs(int from) { return int((GetTickCount() - (dword)from) & 0x7fffffff); }
#endif
void TimeStop::Reset()

View file

@ -9,6 +9,8 @@ static const int _MAX_PATH = PATH_MAX;
dword GetTickCount();
#endif
int msecs(int from = 0);
class TimeStop : Moveable<TimeStop> {
dword starttime;

View file

@ -2,19 +2,6 @@
NAMESPACE_UPP
String WwwFormat(Time tm)
{
static const char *dayofweek[] =
{ "Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat" };
static const char *month[] =
{ "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec" };
return String().Cat()
<< dayofweek[DayOfWeek(tm)] << ", "
<< (int)tm.day << ' ' << month[tm.month - 1]
<< ' ' << (int)tm.year
<< ' ' << Sprintf("%2d:%02d:%02d +0100", tm.hour, tm.minute, tm.second);
}
bool IsSameTextFile(const char *p, const char *q)
{
for(;;)

View file

@ -1,7 +1,6 @@
#ifndef __tweb_util__
#define __tweb_util__
String WwwFormat(Time tm);
bool IsSameTextFile(const char *p, const char *q);
String StringSample(const char *s, int limit);
String GetRandomIdent(int length);