diff --git a/uppsrc/Core/Core.h b/uppsrc/Core/Core.h index 5bfe2b528..962b4281b 100644 --- a/uppsrc/Core/Core.h +++ b/uppsrc/Core/Core.h @@ -319,7 +319,7 @@ NAMESPACE_UPP #include "LocalProcess.h" -#include "Web.h" +#include "Inet.h" #include "Win32Util.h" diff --git a/uppsrc/Core/Core.upp b/uppsrc/Core/Core.upp index 4e3389cb0..89e7e9728 100644 --- a/uppsrc/Core/Core.upp +++ b/uppsrc/Core/Core.upp @@ -146,8 +146,8 @@ file MD5.cpp, SHA1.cpp, Web readonly separator, - Web.h, - WebUtil.cpp, + Inet.h, + InetUtil.cpp, Socket.cpp, Http.cpp, "Runtime linking" readonly separator, diff --git a/uppsrc/Core/Http.cpp b/uppsrc/Core/Http.cpp index 4c553cbec..135e3dd63 100644 --- a/uppsrc/Core/Http.cpp +++ b/uppsrc/Core/Http.cpp @@ -1,607 +1,683 @@ -#include "Core.h" - -NAMESPACE_UPP - -bool HttpRequest_Trace__; - -#define LLOG(x) do { if(HttpRequest_Trace__) RLOG(x); } while(0) - -#ifdef _DEBUG -_DBG_ -// #define ENDZIP -#endif - -void HttpRequest::Trace(bool b) -{ - HttpRequest_Trace__ = b; -} - -void HttpRequest::Init() -{ - port = 0; - proxy_port = 0; - max_header_size = 1000000; - max_content_size = 10000000; - max_redirects = 5; - max_retries = 3; - force_digest = false; - std_headers = true; - hasurlvar = false; - method = METHOD_GET; - phase = START; - redirect_count = 0; - retry_count = 0; - gzip = false; - WhenContent = callback(this, &HttpRequest::ContentOut); - chunk = 4096; - timeout = 120000; - ssl = false; -} - -HttpRequest::HttpRequest() -{ - Init(); -} - -HttpRequest::HttpRequest(const char *url) -{ - Init(); - Url(url); -} - -HttpRequest& HttpRequest::Url(const char *u) -{ - ssl = memcmp(u, "https", 5) == 0; - const char *t = u; - while(*t && *t != '?') - if(*t++ == '/' && *t == '/') { - u = ++t; - break; - } - t = u; - while(*u && *u != ':' && *u != '/' && *u != '?') - u++; - if(*u == '?' && u[1]) - hasurlvar = true; - host = String(t, u); - port = 0; - if(*u == ':') - port = ScanInt(u + 1, &u); - path = u; - int q = path.Find('#'); - if(q >= 0) - path.Trim(q); - return *this; -} - -HttpRequest& HttpRequest::Proxy(const char *p) -{ - const char *t = p; - while(*p && *p != ':') - p++; - proxy_host = String(t, p); - proxy_port = 80; - if(*p++ == ':' && IsDigit(*p)) - proxy_port = ScanInt(p); - return *this; -} - -HttpRequest& HttpRequest::Post(const char *id, const String& data) -{ - POST(); - if(postdata.GetCount()) - postdata << '&'; - postdata << id << '=' << UrlEncode(data); - return *this; -} - -HttpRequest& HttpRequest::UrlVar(const char *id, const String& data) -{ - int c = *path.Last(); - if(hasurlvar && c != '&') - path << '&'; - if(!hasurlvar && c != '?') - path << '?'; - path << id << '=' << UrlEncode(data); - hasurlvar = true; - return *this; -} - -String HttpRequest::CalculateDigest(const String& authenticate) const -{ - const char *p = authenticate; - String realm, qop, nonce, opaque; - while(*p) { - if(!IsAlNum(*p)) { - p++; - continue; - } - else { - const char *b = p; - while(IsAlNum(*p)) - p++; - String var = ToLower(String(b, p)); - String value; - while(*p && (byte)*p <= ' ') - p++; - if(*p == '=') { - p++; - while(*p && (byte)*p <= ' ') - p++; - if(*p == '\"') { - p++; - while(*p && *p != '\"') - if(*p != '\\' || *++p) - value.Cat(*p++); - if(*p == '\"') - p++; - } - else { - b = p; - while(*p && *p != ',' && (byte)*p > ' ') - p++; - value = String(b, p); - } - } - if(var == "realm") - realm = value; - else if(var == "qop") - qop = value; - else if(var == "nonce") - nonce = value; - else if(var == "opaque") - opaque = value; - } - } - String hv1, hv2; - hv1 << username << ':' << realm << ':' << password; - String ha1 = MD5String(hv1); - hv2 << (method == METHOD_GET ? "GET" : method == METHOD_PUT ? "PUT" : method == METHOD_POST ? "POST" : "READ") - << ':' << path; - String ha2 = MD5String(hv2); - int nc = 1; - String cnonce = FormatIntHex(Random(), 8); - String hv; - hv << ha1 - << ':' << nonce - << ':' << FormatIntHex(nc, 8) - << ':' << cnonce - << ':' << qop << ':' << ha2; - String ha = MD5String(hv); - String auth; - auth << "username=" << AsCString(username) - << ", realm=" << AsCString(realm) - << ", nonce=" << AsCString(nonce) - << ", uri=" << AsCString(path) - << ", qop=" << AsCString(qop) - << ", nc=" << AsCString(FormatIntHex(nc, 8)) - << ", cnonce=" << cnonce - << ", response=" << AsCString(ha); - if(!IsNull(opaque)) - auth << ", opaque=" << AsCString(opaque); - return auth; -} - -HttpRequest& HttpRequest::Header(const char *id, const String& data) -{ - request_headers << id << ": " << data << "\r\n"; - return *this; -} - -void HttpRequest::HttpError(const char *s) -{ - if(IsError()) - return; - error = NFormat(t_("%s:%d: ") + String(s), host, port); - LLOG("HTTP ERROR: " << error); - Close(); -} - -void HttpRequest::StartPhase(int s) -{ - phase = s; - LLOG("Starting status " << s << " '" << GetPhaseName() << "' of " << host); - data.Clear(); -} - -bool HttpRequest::Do() -{ - int c1, c2; - switch(phase) { - case START: - retry_count = 0; - redirect_count = 0; - start_time = msecs(); - Start(); - break; - case DNS: - Dns(); - break; - case REQUEST: - if(SendingData()) - break; - StartPhase(HEADER); - break; - case HEADER: - if(ReadingHeader()) - break; - StartBody(); - break; - case BODY: - if(ReadingBody()) - break; - Finish(); - break; - case CHUNK_HEADER: - ReadingChunkHeader(); - break; - case CHUNK_BODY: - if(ReadingBody()) - break; - c1 = Get(); - c2 = Get(); - if(c1 != '\r' || c2 != '\n') - HttpError("missing ending CRLF in chunked transfer"); - StartPhase(CHUNK_HEADER); - break; - case TRAILER: - if(ReadingHeader()) - break; - header.Parse(data); - Finish(); - break; - case FINISHED: - case FAILED: - return false; - default: - NEVER(); - } - - if(phase != FAILED) - if(IsSocketError() || IsError()) - phase = FAILED; - else - if(msecs() - start_time >= timeout) { - HttpError("connection timed out"); - phase = FAILED; - } - else - if(IsAbort()) { - HttpError("connection was aborted"); - phase = FAILED; - } - - if(phase == FAILED) { - if(retry_count++ < max_retries) { - LLOG("HTTP retry on error " << GetErrorDesc()); - StartRequest(); - } - } - return phase != FINISHED && phase != FAILED; -} - -void HttpRequest::Start() -{ - Close(); - ClearError(); - gzip = false; - z.Clear(); - header.Clear(); - - bool use_proxy = !IsNull(proxy_host); - - int p = use_proxy ? proxy_port : port; - if(!p) - p = ssl ? DEFAULT_HTTPS_PORT : DEFAULT_HTTP_PORT; - String h = use_proxy ? proxy_host : host; - if(IsNull(GetTimeout())) { - addrinfo.Execute(h, p); - StartRequest(); - } - else { - addrinfo.Start(h, p); - StartPhase(DNS); - } -} - -void HttpRequest::Dns() -{ - for(int i = 0; i <= Nvl(GetTimeout(), INT_MAX); i++) { - if(!addrinfo.InProgress()) { - StartRequest(); - return; - } - Sleep(1); - } -} - -void HttpRequest::StartRequest() -{ - if(!Connect(addrinfo)) - return; - - if(ssl && !StartSSL()) - return; - - StartPhase(REQUEST); - count = 0; - String ctype = contenttype; - if((method == METHOD_POST || method == METHOD_PUT) && IsNull(ctype)) - ctype = "application/x-www-form-urlencoded"; - switch(method) { - case METHOD_GET: data << "GET "; break; - case METHOD_POST: data << "POST "; break; - case METHOD_PUT: data << "PUT "; break; - case METHOD_HEAD: data << "HEAD "; break; - default: NEVER(); // invalid method - } - String host_port = host; - if(port) - host_port << ':' << port; - String url; - url << "http://" << host_port << Nvl(path, "/"); - if(!IsNull(proxy_host)) - data << url; - else - data << Nvl(path, "/"); - data << " HTTP/1.1\r\n"; - if(std_headers) { - data// << "URL: " << url << "\r\n" - << "Host: " << host_port << "\r\n" - << "Connection: close\r\n" - << "Accept: " << Nvl(accept, "*/*") << "\r\n" - << "Accept-Encoding: gzip\r\n" - << "User-Agent: " << Nvl(agent, "Ultimate++ HTTP client") << "\r\n"; - if(postdata.GetCount()) - data << "Content-Length: " << postdata.GetCount() << "\r\n"; - if(ctype.GetCount()) - data << "Content-Type: " << ctype << "\r\n"; - } - if(!IsNull(proxy_host) && !IsNull(proxy_username)) - data << "Proxy-Authorization: Basic " << Base64Encode(proxy_username + ':' + proxy_password) << "\r\n"; - if(!IsNull(digest)) - data << "Authorization: Digest " << digest << "\r\n"; - else - if(!force_digest && (!IsNull(username) || !IsNull(password))) - data << "Authorization: Basic " << Base64Encode(username + ":" + password) << "\r\n"; - data << request_headers << "\r\n" << postdata; // !!! POST PHASE !!! - LLOG("HTTP REQUEST " << host << ":" << port); - LLOG("HTTP request:\n" << data); -} - -bool HttpRequest::SendingData() -{ - for(;;) { - int n = min(2048, data.GetLength() - count); - n = Put(~data + count, n); - if(n == 0) - break; - count += n; - } - return count < data.GetLength(); -} - -bool HttpRequest::ReadingHeader() -{ - for(;;) { - int c = Get(); - if(c < 0) - return !IsEof(); - else - data.Cat(c); - if(data.GetCount() > 3) { - const char *h = data.Last(); - if(h[0] == '\n' && (h[-1] == '\r' && h[-2] == '\n' || h[-1] == '\n')) - return false; - } - if(data.GetCount() > max_header_size) { - HttpError("HTTP header exceeded " + AsString(max_header_size)); - return true; - } - } -} - -void HttpRequest::ReadingChunkHeader() -{ - for(;;) { - int c = Get(); - if(c < 0) - break; - else - if(c == '\n') { - int n = ScanInt(~data, NULL, 16); - LLOG("HTTP Chunk header: 0x" << data << " = " << n); - if(IsNull(n)) { - HttpError("invalid chunk header"); - break; - } - if(n == 0) { - StartPhase(TRAILER); - break; - } - count += n; - StartPhase(CHUNK_BODY); - break; - } - if(c != '\r') - data.Cat(c); - } -} - -String HttpRequest::GetRedirectUrl() -{ - String redirect_url = TrimLeft(header["location"]); - if(redirect_url.StartsWith("http://") || redirect_url.StartsWith("https://")) - return redirect_url; - String h = (ssl ? "https://" : "http://") + host; - if(*redirect_url != '/') - h << '/'; - h << redirect_url; - return h; -} - -int HttpRequest::GetContentLength() -{ - return Nvl(ScanInt(header["content-length"]), -1); -} - -void HttpRequest::StartBody() -{ - LLOG("HTTP Header received: "); - LLOG(data); - header.Clear(); - if(!header.Parse(data)) { - HttpError("invalid HTTP header"); - return; - } - - if(!header.Response(protocol, status_code, reason_phrase)) { - HttpError("invalid HTTP response"); - return; - } - - LLOG("HTTP status code: " << status_code); - - count = GetContentLength(); - - if(count > 0) - body.Reserve(count); - - if(method == METHOD_HEAD) - phase = FINISHED; - else - if(header["transfer-encoding"] == "chunked") { - count = 0; - StartPhase(CHUNK_HEADER); - } - else - StartPhase(BODY); - body.Clear(); - bodylen = 0; - gzip = GetHeader("content-encoding") == "gzip"; - if(gzip) { - gzip = true; - z.WhenOut = callback(this, &HttpRequest::Out); - z.ChunkSize(chunk).GZip().Decompress(); - } -} - -void HttpRequest::ContentOut(const void *ptr, dword size) -{ - body.Cat((const char *)ptr, size); -} - -void HttpRequest::Out(const void *ptr, dword size) -{ - LLOG("HTTP Out " << size); - if(z.IsError()) - HttpError("gzip format error"); - int64 l = bodylen + size; - if(l > max_content_size) { - HttpError("content length exceeded " + AsString(max_content_size)); - phase = FAILED; - return; - } - WhenContent(ptr, size); - bodylen += size; -} - -bool HttpRequest::ReadingBody() -{ - LLOG("HTTP reading data " << count); - int n = chunk; - if(count >= 0) - n = min(n, count); - String s = Get(n); - if(s.GetCount() == 0) - return !IsEof() && count; -#ifndef ENDZIP - if(gzip) - z.Put(~s, s.GetCount()); - else -#endif - Out(~s, s.GetCount()); - if(count >= 0) { - count -= s.GetCount(); - return !IsEof() && count > 0; - } - return !IsEof(); -} - -void HttpRequest::CopyCookies() -{ - int q = header.fields.Find("set-cookie"); - while(q >= 0) { - Cookie(header.fields[q]); - q = header.fields.FindNext(q); - } -} - -void HttpRequest::Finish() -{ - if(gzip) { - #ifdef ENDZIP - body = GZDecompress(body); - if(body.IsVoid()) { - HttpError("gzip decompress at finish error"); - phase = FAILED; - return; - } - #else - z.End(); - if(z.IsError()) { - HttpError("gzip format error (finish)"); - phase = FAILED; - return; - } - #endif - } - Close(); - if(status_code == 401 && !IsNull(username)) { - String authenticate = header["www-authenticate"]; - if(authenticate.GetCount() && redirect_count++ < max_redirects) { - LLOG("HTTP auth digest"); - CopyCookies(); - Digest(CalculateDigest(authenticate)); - Start(); - return; - } - } - if(status_code >= 300 && status_code < 400) { - String url = GetRedirectUrl(); - if(url.GetCount() && redirect_count++ < max_redirects) { - LLOG("HTTP redirect " << url); - Url(url); - CopyCookies(); - Start(); - retry_count = 0; - return; - } - } - phase = FINISHED; -} - -String HttpRequest::Execute() -{ - while(Do()); - return IsSuccess() ? GetContent() : String::GetVoid(); -} - -String HttpRequest::GetPhaseName() const -{ - static const char *m[] = { - "Start", - "Resolving host name", - "Sending request", - "Receiving header", - "Receiving content", - "Receiving chunk header", - "Receiving content chunk", - "Receiving trailer", - "Finished", - "Failed", - }; - return phase >= 0 && phase <= FAILED ? m[phase] : ""; -} - -END_UPP_NAMESPACE +#include "Core.h" + +NAMESPACE_UPP + +bool HttpRequest_Trace__; + +#define LLOG(x) do { if(HttpRequest_Trace__) RLOG(x); } while(0) + +#ifdef _DEBUG +_DBG_ +// #define ENDZIP +#endif + +void HttpRequest::Trace(bool b) +{ + HttpRequest_Trace__ = b; +} + +void HttpRequest::Init() +{ + port = 0; + proxy_port = 0; + ssl_proxy_port = 0; + max_header_size = 1000000; + max_content_size = 10000000; + max_redirects = 5; + max_retries = 3; + force_digest = false; + std_headers = true; + hasurlvar = false; + method = METHOD_GET; + phase = START; + redirect_count = 0; + retry_count = 0; + gzip = false; + WhenContent = callback(this, &HttpRequest::ContentOut); + chunk = 4096; + timeout = 120000; + ssl = false; +} + +HttpRequest::HttpRequest() +{ + Init(); +} + +HttpRequest::HttpRequest(const char *url) +{ + Init(); + Url(url); +} + +HttpRequest& HttpRequest::Url(const char *u) +{ + ssl = memcmp(u, "https", 5) == 0; + const char *t = u; + while(*t && *t != '?') + if(*t++ == '/' && *t == '/') { + u = ++t; + break; + } + t = u; + while(*u && *u != ':' && *u != '/' && *u != '?') + u++; + if(*u == '?' && u[1]) + hasurlvar = true; + host = String(t, u); + port = 0; + if(*u == ':') + port = ScanInt(u + 1, &u); + path = u; + int q = path.Find('#'); + if(q >= 0) + path.Trim(q); + return *this; +} + +static +void sParseProxyUrl(const char *p, String& proxy_host, int proxy_port) +{ + const char *t = p; + while(*p && *p != ':') + p++; + proxy_host = String(t, p); + if(*p++ == ':' && IsDigit(*p)) + proxy_port = ScanInt(p); +} + +HttpRequest& HttpRequest::Proxy(const char *url) +{ + proxy_port = 80; + sParseProxyUrl(url, proxy_host, proxy_port); + return *this; +} + +HttpRequest& HttpRequest::SSLProxy(const char *url) +{ + ssl_proxy_port = 8080; + sParseProxyUrl(url, ssl_proxy_host, ssl_proxy_port); + return *this; +} + +HttpRequest& HttpRequest::Post(const char *id, const String& data) +{ + POST(); + if(postdata.GetCount()) + postdata << '&'; + postdata << id << '=' << UrlEncode(data); + return *this; +} + +HttpRequest& HttpRequest::UrlVar(const char *id, const String& data) +{ + int c = *path.Last(); + if(hasurlvar && c != '&') + path << '&'; + if(!hasurlvar && c != '?') + path << '?'; + path << id << '=' << UrlEncode(data); + hasurlvar = true; + return *this; +} + +String HttpRequest::CalculateDigest(const String& authenticate) const +{ + const char *p = authenticate; + String realm, qop, nonce, opaque; + while(*p) { + if(!IsAlNum(*p)) { + p++; + continue; + } + else { + const char *b = p; + while(IsAlNum(*p)) + p++; + String var = ToLower(String(b, p)); + String value; + while(*p && (byte)*p <= ' ') + p++; + if(*p == '=') { + p++; + while(*p && (byte)*p <= ' ') + p++; + if(*p == '\"') { + p++; + while(*p && *p != '\"') + if(*p != '\\' || *++p) + value.Cat(*p++); + if(*p == '\"') + p++; + } + else { + b = p; + while(*p && *p != ',' && (byte)*p > ' ') + p++; + value = String(b, p); + } + } + if(var == "realm") + realm = value; + else if(var == "qop") + qop = value; + else if(var == "nonce") + nonce = value; + else if(var == "opaque") + opaque = value; + } + } + String hv1, hv2; + hv1 << username << ':' << realm << ':' << password; + String ha1 = MD5String(hv1); + hv2 << (method == METHOD_GET ? "GET" : method == METHOD_PUT ? "PUT" : method == METHOD_POST ? "POST" : "READ") + << ':' << path; + String ha2 = MD5String(hv2); + int nc = 1; + String cnonce = FormatIntHex(Random(), 8); + String hv; + hv << ha1 + << ':' << nonce + << ':' << FormatIntHex(nc, 8) + << ':' << cnonce + << ':' << qop << ':' << ha2; + String ha = MD5String(hv); + String auth; + auth << "username=" << AsCString(username) + << ", realm=" << AsCString(realm) + << ", nonce=" << AsCString(nonce) + << ", uri=" << AsCString(path) + << ", qop=" << AsCString(qop) + << ", nc=" << AsCString(FormatIntHex(nc, 8)) + << ", cnonce=" << cnonce + << ", response=" << AsCString(ha); + if(!IsNull(opaque)) + auth << ", opaque=" << AsCString(opaque); + return auth; +} + +HttpRequest& HttpRequest::Header(const char *id, const String& data) +{ + request_headers << id << ": " << data << "\r\n"; + return *this; +} + +void HttpRequest::HttpError(const char *s) +{ + if(IsError()) + return; + error = NFormat(t_("%s:%d: ") + String(s), host, port); + LLOG("HTTP ERROR: " << error); + Close(); +} + +void HttpRequest::StartPhase(int s) +{ + phase = s; + LLOG("Starting status " << s << " '" << GetPhaseName() << "', url: " << host); + data.Clear(); +} + +bool HttpRequest::Do() +{ + int c1, c2; + switch(phase) { + case START: + retry_count = 0; + redirect_count = 0; + start_time = msecs(); + Start(); + break; + case DNS: + Dns(); + break; + case SSLPROXYREQUEST: + if(SendingData()) + break; + StartPhase(SSLPROXYRESPONSE); + break; + case SSLPROXYRESPONSE: + if(ReadingHeader()) + break; + ProcessSSLProxyResponse(); + break; + case SSLHANDSHAKE: + if(SSLHandshake()) + break; + StartRequest(); + break; + case REQUEST: + if(SendingData()) + break; + StartPhase(HEADER); + break; + case HEADER: + if(ReadingHeader()) + break; + StartBody(); + break; + case BODY: + if(ReadingBody()) + break; + Finish(); + break; + case CHUNK_HEADER: + ReadingChunkHeader(); + break; + case CHUNK_BODY: + if(ReadingBody()) + break; + c1 = Get(); + c2 = Get(); + if(c1 != '\r' || c2 != '\n') + HttpError("missing ending CRLF in chunked transfer"); + StartPhase(CHUNK_HEADER); + break; + case TRAILER: + if(ReadingHeader()) + break; + header.Parse(data); + Finish(); + break; + case FINISHED: + case FAILED: + return false; + default: + NEVER(); + } + + if(phase != FAILED) + if(IsSocketError() || IsError()) + phase = FAILED; + else + if(msecs() - start_time >= timeout) { + HttpError("connection timed out"); + phase = FAILED; + } + else + if(IsAbort()) { + HttpError("connection was aborted"); + phase = FAILED; + } + + if(phase == FAILED) { + if(retry_count++ < max_retries) { + LLOG("HTTP retry on error " << GetErrorDesc()); + start_time = msecs(); + Start(); + } + } + return phase != FINISHED && phase != FAILED; +} + +void HttpRequest::Start() +{ + Close(); + ClearError(); + gzip = false; + z.Clear(); + header.Clear(); + + bool use_proxy = !IsNull(ssl ? ssl_proxy_host : proxy_host); + + int p = use_proxy ? (ssl ? ssl_proxy_port : proxy_port) : port; + if(!p) + p = ssl ? DEFAULT_HTTPS_PORT : DEFAULT_HTTP_PORT; + String h = use_proxy ? ssl ? ssl_proxy_host : proxy_host : host; + + if(IsNull(GetTimeout())) { + addrinfo.Execute(h, p); + StartConnect(); + } + else { + addrinfo.Start(h, p); + StartPhase(DNS); + } +} + +void HttpRequest::Dns() +{ + for(int i = 0; i <= Nvl(GetTimeout(), INT_MAX); i++) { + if(!addrinfo.InProgress()) { + StartConnect(); + return; + } + Sleep(1); + } +} + +void HttpRequest::StartConnect() +{ + if(!Connect(addrinfo)) + return; + if(ssl && ssl_proxy_host.GetCount()) { + StartPhase(SSLPROXYREQUEST); + String host_port = host; + if(port) + host_port << ':' << port; + else + host_port << ":443"; + data << "CONNECT " << host_port << " HTTP/1.1\r\n" + << "Host: " << host_port << "\r\n"; + if(!IsNull(ssl_proxy_username)) + data << "Proxy-Authorization: Basic " + << Base64Encode(proxy_username + ':' + proxy_password) << "\r\n"; + data << "\r\n"; + count = 0; + LLOG("HTTPS proxy request:\n" << data); + } + else + AfterConnect(); +} + +void HttpRequest::ProcessSSLProxyResponse() +{ + LLOG("HTTPS proxy response:\n" << data); + int q = min(data.Find('\r'), data.Find('\n')); + if(q >= 0) + data.Trim(q); + if(!data.StartsWith("HTTP") || data.Find(" 2") < 0) { + HttpError("Invalid proxy reply: " + data); + return; + } + AfterConnect(); +} + +void HttpRequest::AfterConnect() +{ + if(ssl && !StartSSL()) + return; + if(ssl) + StartPhase(SSLHANDSHAKE); + else + StartRequest(); +} + +void HttpRequest::StartRequest() +{ + StartPhase(REQUEST); + count = 0; + String ctype = contenttype; + if((method == METHOD_POST || method == METHOD_PUT) && IsNull(ctype)) + ctype = "application/x-www-form-urlencoded"; + switch(method) { + case METHOD_GET: data << "GET "; break; + case METHOD_POST: data << "POST "; break; + case METHOD_PUT: data << "PUT "; break; + case METHOD_HEAD: data << "HEAD "; break; + default: NEVER(); // invalid method + } + String host_port = host; + if(port) + host_port << ':' << port; + String url; + url << "http://" << host_port << Nvl(path, "/"); + if(!IsNull(proxy_host) && !ssl) + data << url; + else + data << Nvl(path, "/"); + data << " HTTP/1.1\r\n"; + if(std_headers) { + data << "URL: " << url << "\r\n" + << "Host: " << host_port << "\r\n" + << "Connection: close\r\n" + << "Accept: " << Nvl(accept, "*/*") << "\r\n" + << "Accept-Encoding: gzip\r\n" + << "User-Agent: " << Nvl(agent, "Ultimate++ HTTP client") << "\r\n"; + if(postdata.GetCount()) + data << "Content-Length: " << postdata.GetCount() << "\r\n"; + if(ctype.GetCount()) + data << "Content-Type: " << ctype << "\r\n"; + } + if(!IsNull(proxy_host) && !IsNull(proxy_username)) + data << "Proxy-Authorization: Basic " << Base64Encode(proxy_username + ':' + proxy_password) << "\r\n"; + if(!IsNull(digest)) + data << "Authorization: Digest " << digest << "\r\n"; + else + if(!force_digest && (!IsNull(username) || !IsNull(password))) + data << "Authorization: Basic " << Base64Encode(username + ":" + password) << "\r\n"; + data << request_headers << "\r\n" << postdata; // !!! POST PHASE !!! + LLOG("HTTP REQUEST " << host << ":" << port); + LLOG("HTTP request:\n" << data); +} + +bool HttpRequest::SendingData() +{ + for(;;) { + int n = min(2048, data.GetLength() - count); + n = Put(~data + count, n); + if(n == 0) + break; + count += n; + } + return count < data.GetLength(); +} + +bool HttpRequest::ReadingHeader() +{ + for(;;) { + int c = Get(); + if(c < 0) + return !IsEof(); + else + data.Cat(c); + if(data.GetCount() > 3) { + const char *h = data.Last(); + if(h[0] == '\n' && (h[-1] == '\r' && h[-2] == '\n' || h[-1] == '\n')) + return false; + } + if(data.GetCount() > max_header_size) { + HttpError("HTTP header exceeded " + AsString(max_header_size)); + return true; + } + } +} + +void HttpRequest::ReadingChunkHeader() +{ + for(;;) { + int c = Get(); + if(c < 0) + break; + else + if(c == '\n') { + int n = ScanInt(~data, NULL, 16); + LLOG("HTTP Chunk header: 0x" << data << " = " << n); + if(IsNull(n)) { + HttpError("invalid chunk header"); + break; + } + if(n == 0) { + StartPhase(TRAILER); + break; + } + count += n; + StartPhase(CHUNK_BODY); + break; + } + if(c != '\r') + data.Cat(c); + } +} + +String HttpRequest::GetRedirectUrl() +{ + String redirect_url = TrimLeft(header["location"]); + if(redirect_url.StartsWith("http://") || redirect_url.StartsWith("https://")) + return redirect_url; + String h = (ssl ? "https://" : "http://") + host; + if(*redirect_url != '/') + h << '/'; + h << redirect_url; + return h; +} + +int HttpRequest::GetContentLength() +{ + return Nvl(ScanInt(header["content-length"]), -1); +} + +void HttpRequest::StartBody() +{ + LLOG("HTTP Header received: "); + LLOG(data); + header.Clear(); + if(!header.Parse(data)) { + HttpError("invalid HTTP header"); + return; + } + + if(!header.Response(protocol, status_code, reason_phrase)) { + HttpError("invalid HTTP response"); + return; + } + + LLOG("HTTP status code: " << status_code); + + count = GetContentLength(); + + if(count > 0) + body.Reserve(count); + + if(method == METHOD_HEAD) + phase = FINISHED; + else + if(header["transfer-encoding"] == "chunked") { + count = 0; + StartPhase(CHUNK_HEADER); + } + else + StartPhase(BODY); + body.Clear(); + bodylen = 0; + gzip = GetHeader("content-encoding") == "gzip"; + if(gzip) { + gzip = true; + z.WhenOut = callback(this, &HttpRequest::Out); + z.ChunkSize(chunk).GZip().Decompress(); + } +} + +void HttpRequest::ContentOut(const void *ptr, dword size) +{ + body.Cat((const char *)ptr, size); +} + +void HttpRequest::Out(const void *ptr, dword size) +{ + LLOG("HTTP Out " << size); + if(z.IsError()) + HttpError("gzip format error"); + int64 l = bodylen + size; + if(l > max_content_size) { + HttpError("content length exceeded " + AsString(max_content_size)); + phase = FAILED; + return; + } + WhenContent(ptr, size); + bodylen += size; +} + +bool HttpRequest::ReadingBody() +{ + LLOG("HTTP reading data " << count); + int n = chunk; + if(count >= 0) + n = min(n, count); + String s = Get(n); + if(s.GetCount() == 0) + return !IsEof() && count; +#ifndef ENDZIP + if(gzip) + z.Put(~s, s.GetCount()); + else +#endif + Out(~s, s.GetCount()); + if(count >= 0) { + count -= s.GetCount(); + return !IsEof() && count > 0; + } + return !IsEof(); +} + +void HttpRequest::CopyCookies() +{ + int q = header.fields.Find("set-cookie"); + while(q >= 0) { + Cookie(header.fields[q]); + q = header.fields.FindNext(q); + } +} + +void HttpRequest::Finish() +{ + if(gzip) { + #ifdef ENDZIP + body = GZDecompress(body); + if(body.IsVoid()) { + HttpError("gzip decompress at finish error"); + phase = FAILED; + return; + } + #else + z.End(); + if(z.IsError()) { + HttpError("gzip format error (finish)"); + phase = FAILED; + return; + } + #endif + } + Close(); + if(status_code == 401 && !IsNull(username)) { + String authenticate = header["www-authenticate"]; + if(authenticate.GetCount() && redirect_count++ < max_redirects) { + LLOG("HTTP auth digest"); + CopyCookies(); + Digest(CalculateDigest(authenticate)); + Start(); + return; + } + } + if(status_code >= 300 && status_code < 400) { + String url = GetRedirectUrl(); + if(url.GetCount() && redirect_count++ < max_redirects) { + LLOG("HTTP redirect " << url); + Url(url); + CopyCookies(); + Start(); + retry_count = 0; + return; + } + } + phase = FINISHED; +} + +String HttpRequest::Execute() +{ + while(Do()) + LLOG("HTTP Execute: " << GetPhaseName()); + return IsSuccess() ? GetContent() : String::GetVoid(); +} + +String HttpRequest::GetPhaseName() const +{ + static const char *m[] = { + "Start", + "Resolving host name", + "SSL proxy request", + "SSL proxy response", + "SSL handshake", + "Sending request", + "Receiving header", + "Receiving content", + "Receiving chunk header", + "Receiving content chunk", + "Receiving trailer", + "Finished", + "Failed", + }; + return phase >= 0 && phase <= FAILED ? m[phase] : ""; +} + +END_UPP_NAMESPACE diff --git a/uppsrc/Core/Web.h b/uppsrc/Core/Inet.h similarity index 90% rename from uppsrc/Core/Web.h rename to uppsrc/Core/Inet.h index 1c968cae2..ff284d734 100644 --- a/uppsrc/Core/Web.h +++ b/uppsrc/Core/Inet.h @@ -1,376 +1,407 @@ -String FormatIP(dword _ip); - -String UrlEncode(const String& s); -String UrlEncode(const String& s, const char *specials); -String UrlDecode(const char *b, const char *e); -inline String UrlDecode(const String& s) { return UrlDecode(s.Begin(), s.End() ); } - -String Base64Encode(const char *b, const char *e); -inline String Base64Encode(const String& data) { return Base64Encode(data.Begin(), data.End()); } -String Base64Decode(const char *b, const char *e); -inline String Base64Decode(const String& data) { return Base64Decode(data.Begin(), data.End()); } - -class IpAddrInfo { - enum { COUNT = 32 }; - struct Entry { - const char *host; - const char *port; - int status; - addrinfo *addr; - }; - static Entry pool[COUNT]; - - enum { - EMPTY = 0, WORKING, CANCELED, RESOLVED, FAILED - }; - - String host, port; - Entry *entry; - Entry exe[1]; - - static void EnterPool(); - static void LeavePool(); - static rawthread_t rawthread__ Thread(void *ptr); - - void Start(); - -public: - void Start(const String& host, int port); - bool InProgress(); - bool Execute(const String& host, int port); - addrinfo *GetResult(); - void Clear(); - - IpAddrInfo(); - ~IpAddrInfo() { Clear(); } -}; - -enum { WAIT_READ = 1, WAIT_WRITE = 2, WAIT_EXCEPTION = 4, WAIT_ALL = 7 }; - -class TcpSocket { - enum { BUFFERSIZE = 512 }; - enum { NONE, CONNECT, ACCEPT }; - SOCKET socket; - int mode; - char buffer[BUFFERSIZE]; - char *ptr; - char *end; - bool is_eof; - bool is_error; - bool is_abort; - bool ipv6; - - int timeout; - int waitstep; - int done; - - int errorcode; - String errordesc; - - struct SSL { - virtual bool Start() = 0; - virtual bool Wait(dword flags) = 0; - virtual int Send(const void *buffer, int maxlen) = 0; - virtual int Recv(void *buffer, int maxlen) = 0; - virtual void Close() = 0; - - virtual ~SSL() {} - }; - - One ssl; - - struct SSLImp; - friend struct SSLImp; - - static SSL *(*CreateSSL)(TcpSocket& socket); - static SSL *CreateSSLImp(TcpSocket& socket); - - friend void InitCreateSSL(); - - bool RawWait(dword flags); - SOCKET AcceptRaw(dword *ipaddr, int timeout_msec); - bool Open(int family, int type, int protocol); - int RawRecv(void *buffer, int maxlen); - int Recv(void *buffer, int maxlen); - int RawSend(const void *buffer, int maxlen); - int Send(const void *buffer, int maxlen); - bool RawConnect(addrinfo *info); - void RawClose(); - - void ReadBuffer(); - int Get_(); - int Peek_(); - - void Reset(); - - void SetSockError(const char *context, int code, const char *errdesc); - void SetSockError(const char *context, const char *errdesc); - void SetSockError(const char *context); - - static int GetErrorCode(); - static bool WouldBlock(); - -public: - Callback WhenWait; - - static String GetHostName(); - - int GetDone() const { return done; } - - static void Init(); - - bool IsOpen() const { return socket != INVALID_SOCKET; } - bool IsEof() const { return is_eof && ptr == end; } - - bool IsError() const { return is_error; } - void ClearError() { is_error = false; errorcode = 0; errordesc.Clear(); } - int GetError() const { return errorcode; } - String GetErrorDesc() const { return errordesc; } - - void Abort() { is_abort = true; } - bool IsAbort() const { return is_abort; } - void ClearAbort() { is_abort = false; } - - SOCKET GetSOCKET() const { return socket; } - String GetPeerAddr() const; - - void Attach(SOCKET socket); - bool Connect(const char *host, int port); - bool Connect(IpAddrInfo& info); - bool Listen(int port, int listen_count, bool ipv6 = false, bool reuse = true); - bool Accept(TcpSocket& listen_socket); - void Close(); - void Shutdown(); - - void NoDelay(); - void Linger(int msecs); - void NoLinger() { Linger(Null); } - void Reuse(bool reuse = true); - - bool Wait(dword events); - bool WaitRead() { return Wait(WAIT_READ); } - bool WaitWrite() { return Wait(WAIT_WRITE); } - - int Peek() { return ptr < end ? *ptr : Peek_(); } - int Term() { return Peek(); } - int Get() { return ptr < end ? *ptr++ : Get_(); } - int Get(void *buffer, int len); - String Get(int len); - int GetAll(void *buffer, int len) { return Get(buffer, len) == len; } - String GetAll(int len) { String s = Get(len); return s.GetCount() == len ? s : String::GetVoid(); } - String GetLine(int maxlen = 2000000); - - int Put(const char *s, int len); - int Put(const String& s) { return Put(s.Begin(), s.GetLength()); } - bool PutAll(const char *s, int len) { return Put(s, len) == len; } - bool PutAll(const String& s) { return Put(s) == s.GetCount(); } - - bool StartSSL(); - bool IsSSL() const { return ssl; } - - TcpSocket& Timeout(int ms) { timeout = ms; return *this; } - int GetTimeout() const { return timeout; } - TcpSocket& Blocking() { return Timeout(Null); } - - TcpSocket(); - ~TcpSocket() { Close(); } -}; - -class SocketWaitEvent { - Vector< Tuple2 > socket; - fd_set read[1], write[1], exception[1]; - -public: - void Clear() { socket.Clear(); } - void Add(SOCKET s, dword events = WAIT_ALL) { socket.Add(MakeTuple((int)s, events)); } - void Add(TcpSocket& s, dword events = WAIT_ALL) { Add(s.GetSOCKET(), events); } - int Wait(int timeout); - dword Get(int i) const; - dword operator[](int i) const { return Get(i); } - - SocketWaitEvent(); -}; - -struct HttpHeader { - String first_line; - VectorMap fields; - - String operator[](const char *id) { return fields.Get(id, Null); } - - bool Response(String& protocol, int& code, String& reason); - bool Request(String& method, String& uri, String& version); - - void Clear(); - bool Parse(const String& hdrs); -}; - -class HttpRequest : public TcpSocket { - int phase; - String data; - int count; - - HttpHeader header; - - String error; - String body; - - enum { - DEFAULT_HTTP_PORT = 80, - DEFAULT_HTTPS_PORT = 443 - }; - - enum { - METHOD_GET, - METHOD_POST, - METHOD_HEAD, - METHOD_PUT, - }; - - int max_header_size; - int max_content_size; - int max_redirects; - int max_retries; - int timeout; - - String host; - int port; - String proxy_host; - int proxy_port; - String proxy_username; - String proxy_password; - String path; - bool ssl; - - int method; - String accept; - String agent; - bool force_digest; - bool is_post; - bool std_headers; - bool hasurlvar; - String contenttype; - String username; - String password; - String digest; - String request_headers; - String postdata; - - String protocol; - int status_code; - String reason_phrase; - - int start_time; - int retry_count; - int redirect_count; - - int chunk; - - IpAddrInfo addrinfo; - int bodylen; - bool gzip; - Zlib z; - - void Init(); - - void StartPhase(int s); - void Start(); - void Dns(); - void StartRequest(); - bool SendingData(); - bool ReadingHeader(); - void StartBody(); - bool ReadingBody(); - void ReadingChunkHeader(); - void Finish(); - - void CopyCookies(); - - void HttpError(const char *s); - void ContentOut(const void *ptr, dword size); - void Out(const void *ptr, dword size); - - String CalculateDigest(const String& authenticate) const; - -public: - Callback2 WhenContent; - - HttpRequest& MaxHeaderSize(int m) { max_header_size = m; return *this; } - HttpRequest& MaxContentSize(int m) { max_content_size = m; return *this; } - HttpRequest& MaxRedirect(int n) { max_redirects = n; return *this; } - HttpRequest& MaxRetries(int n) { max_retries = n; return *this; } - HttpRequest& RequestTimeout(int ms) { timeout = ms; return *this; } - HttpRequest& ChunkSize(int n) { chunk = n; return *this; } - - HttpRequest& Method(int m) { method = m; return *this; } - HttpRequest& GET() { return Method(METHOD_GET); } - HttpRequest& POST() { return Method(METHOD_POST); } - HttpRequest& HEAD() { return Method(METHOD_HEAD); } - HttpRequest& PUT() { return Method(METHOD_PUT); } - - HttpRequest& Host(const String& h) { host = h; return *this; } - HttpRequest& Port(int p) { port = p; return *this; } - HttpRequest& SSL(bool b = true) { ssl = b; return *this; } - HttpRequest& Path(const String& p) { path = p; return *this; } - HttpRequest& User(const String& u, const String& p) { username = u; password = p; return *this; } - HttpRequest& Digest() { force_digest = true; return *this; } - HttpRequest& Digest(const String& d) { digest = d; return *this; } - HttpRequest& Url(const char *url); - HttpRequest& UrlVar(const char *id, const String& data); - HttpRequest& operator()(const char *id, const String& data) { return UrlVar(id, data); } - HttpRequest& PostData(const String& pd) { postdata = pd; return *this; } - HttpRequest& PostUData(const String& pd) { return PostData(UrlEncode(pd)); } - HttpRequest& Post(const String& data) { POST(); return PostData(data); } - HttpRequest& Post(const char *id, const String& data); - - HttpRequest& Headers(const String& h) { request_headers = h; return *this; } - HttpRequest& ClearHeaders() { return Headers(Null); } - HttpRequest& AddHeaders(const String& h) { request_headers.Cat(h); return *this; } - HttpRequest& Header(const char *id, const String& data); - HttpRequest& Cookie(const String& cookie) { return Header("Cookie", cookie); } - - HttpRequest& StdHeaders(bool sh) { std_headers = sh; return *this; } - HttpRequest& NoStdHeaders() { return StdHeaders(false); } - HttpRequest& Accept(const String& a) { accept = a; return *this; } - HttpRequest& UserAgent(const String& a) { agent = a; return *this; } - HttpRequest& ContentType(const String& a) { contenttype = a; return *this; } - - HttpRequest& Proxy(const String& host, int port) { proxy_host = host; proxy_port = port; return *this; } - HttpRequest& Proxy(const char *url); - HttpRequest& ProxyAuth(const String& u, const String& p) { proxy_username = u; proxy_password = p; return *this; } - - bool IsSocketError() const { return TcpSocket::IsError(); } - bool IsHttpError() const { return !IsNull(error) ; } - bool IsError() const { return IsSocketError() || IsHttpError(); } - String GetErrorDesc() const { return IsSocketError() ? TcpSocket::GetErrorDesc() : error; } - void ClearError() { TcpSocket::ClearError(); error.Clear(); } - - String GetHeader(const char *s) { return header[s]; } - String operator[](const char *s) { return GetHeader(s); } - String GetRedirectUrl(); - int GetContentLength(); - int GetStatusCode() const { return status_code; } - String GetReasonPhrase() const { return reason_phrase; } - - String GetContent() const { return body; } - String operator~() const { return GetContent(); } - operator String() const { return GetContent(); } - void ClearContent() { body.Clear(); } - - enum Phase { - START, DNS, REQUEST, HEADER, BODY, CHUNK_HEADER, CHUNK_BODY, TRAILER, FINISHED, FAILED - }; - - bool Do(); - int GetPhase() const { return phase; } - String GetPhaseName() const; - bool InProgress() const { return phase != FAILED && phase != FINISHED; } - bool IsFailure() const { return phase == FAILED; } - bool IsSuccess() const { return phase == FINISHED && status_code >= 200 && status_code < 300; } - - String Execute(); - - HttpRequest(); - HttpRequest(const char *url); - - static void Trace(bool b = true); -}; +String FormatIP(dword _ip); + +String UrlEncode(const String& s); +String UrlEncode(const String& s, const char *specials); +String UrlDecode(const char *b, const char *e); +inline String UrlDecode(const String& s) { return UrlDecode(s.Begin(), s.End() ); } + +String Base64Encode(const char *b, const char *e); +inline String Base64Encode(const String& data) { return Base64Encode(data.Begin(), data.End()); } +String Base64Decode(const char *b, const char *e); +inline String Base64Decode(const String& data) { return Base64Decode(data.Begin(), data.End()); } + +class IpAddrInfo { + enum { COUNT = 32 }; + struct Entry { + const char *host; + const char *port; + int status; + addrinfo *addr; + }; + static Entry pool[COUNT]; + + enum { + EMPTY = 0, WORKING, CANCELED, RESOLVED, FAILED + }; + + String host, port; + Entry *entry; + Entry exe[1]; + + static void EnterPool(); + static void LeavePool(); + static rawthread_t rawthread__ Thread(void *ptr); + + void Start(); + +public: + void Start(const String& host, int port); + bool InProgress(); + bool Execute(const String& host, int port); + addrinfo *GetResult(); + void Clear(); + + IpAddrInfo(); + ~IpAddrInfo() { Clear(); } +}; + +enum { WAIT_READ = 1, WAIT_WRITE = 2, WAIT_EXCEPTION = 4, WAIT_ALL = 7 }; + +struct SSLInfo { + String cipher; + bool cert_avail; + bool cert_verified; // Peer verification not yet working - this is always false + String cert_subject; + String cert_issuer; + Date cert_notbefore; + Date cert_notafter; + int cert_version; + String cert_serial; +}; + +class TcpSocket { + enum { BUFFERSIZE = 512 }; + enum { NONE, CONNECT, ACCEPT, SSL_CONNECTED }; + SOCKET socket; + int mode; + char buffer[BUFFERSIZE]; + char *ptr; + char *end; + bool is_eof; + bool is_error; + bool is_abort; + bool ipv6; + + int timeout; + int waitstep; + int done; + + int errorcode; + String errordesc; + + struct SSL { + virtual bool Start() = 0; + virtual bool Wait(dword flags) = 0; + virtual int Send(const void *buffer, int maxlen) = 0; + virtual int Recv(void *buffer, int maxlen) = 0; + virtual void Close() = 0; + virtual dword Handshake() = 0; + + virtual ~SSL() {} + }; + + One ssl; + One sslinfo; + + struct SSLImp; + friend struct SSLImp; + + static SSL *(*CreateSSL)(TcpSocket& socket); + static SSL *CreateSSLImp(TcpSocket& socket); + + friend void InitCreateSSL(); + + bool RawWait(dword flags); + SOCKET AcceptRaw(dword *ipaddr, int timeout_msec); + bool Open(int family, int type, int protocol); + int RawRecv(void *buffer, int maxlen); + int Recv(void *buffer, int maxlen); + int RawSend(const void *buffer, int maxlen); + int Send(const void *buffer, int maxlen); + bool RawConnect(addrinfo *info); + void RawClose(); + + void ReadBuffer(); + int Get_(); + int Peek_(); + + void Reset(); + + void SetSockError(const char *context, int code, const char *errdesc); + void SetSockError(const char *context, const char *errdesc); + void SetSockError(const char *context); + + static int GetErrorCode(); + static bool WouldBlock(); + +public: + Callback WhenWait; + + static String GetHostName(); + + int GetDone() const { return done; } + + static void Init(); + + bool IsOpen() const { return socket != INVALID_SOCKET; } + bool IsEof() const { return is_eof && ptr == end; } + + bool IsError() const { return is_error; } + void ClearError() { is_error = false; errorcode = 0; errordesc.Clear(); } + int GetError() const { return errorcode; } + String GetErrorDesc() const { return errordesc; } + + void Abort() { is_abort = true; } + bool IsAbort() const { return is_abort; } + void ClearAbort() { is_abort = false; } + + SOCKET GetSOCKET() const { return socket; } + String GetPeerAddr() const; + + void Attach(SOCKET socket); + bool Connect(const char *host, int port); + bool Connect(IpAddrInfo& info); + bool Listen(int port, int listen_count, bool ipv6 = false, bool reuse = true); + bool Accept(TcpSocket& listen_socket); + void Close(); + void Shutdown(); + + void NoDelay(); + void Linger(int msecs); + void NoLinger() { Linger(Null); } + void Reuse(bool reuse = true); + + bool Wait(dword events); + bool WaitRead() { return Wait(WAIT_READ); } + bool WaitWrite() { return Wait(WAIT_WRITE); } + + int Peek() { return ptr < end ? *ptr : Peek_(); } + int Term() { return Peek(); } + int Get() { return ptr < end ? *ptr++ : Get_(); } + int Get(void *buffer, int len); + String Get(int len); + int GetAll(void *buffer, int len) { return Get(buffer, len) == len; } + String GetAll(int len); + String GetLine(int maxlen = 2000000); + + int Put(const char *s, int len); + int Put(const String& s) { return Put(s.Begin(), s.GetLength()); } + bool PutAll(const char *s, int len) { return Put(s, len) == len; } + bool PutAll(const String& s) { return Put(s) == s.GetCount(); } + + bool StartSSL(); + bool IsSSL() const { return ssl; } + bool SSLHandshake(); + const SSLInfo *GetSSLInfo() const { return ~sslinfo; } + + TcpSocket& Timeout(int ms) { timeout = ms; return *this; } + int GetTimeout() const { return timeout; } + TcpSocket& Blocking() { return Timeout(Null); } + + TcpSocket(); + ~TcpSocket() { Close(); } +}; + +class SocketWaitEvent { + Vector< Tuple2 > socket; + fd_set read[1], write[1], exception[1]; + +public: + void Clear() { socket.Clear(); } + void Add(SOCKET s, dword events = WAIT_ALL) { socket.Add(MakeTuple((int)s, events)); } + void Add(TcpSocket& s, dword events = WAIT_ALL) { Add(s.GetSOCKET(), events); } + int Wait(int timeout); + dword Get(int i) const; + dword operator[](int i) const { return Get(i); } + + SocketWaitEvent(); +}; + +struct HttpHeader { + String first_line; + VectorMap fields; + + String operator[](const char *id) { return fields.Get(id, Null); } + + bool Response(String& protocol, int& code, String& reason); + bool Request(String& method, String& uri, String& version); + + void Clear(); + bool Parse(const String& hdrs); +}; + +class HttpRequest : public TcpSocket { + int phase; + String data; + int count; + + HttpHeader header; + + String error; + String body; + + enum { + DEFAULT_HTTP_PORT = 80, + DEFAULT_HTTPS_PORT = 443 + }; + + enum { + METHOD_GET, + METHOD_POST, + METHOD_HEAD, + METHOD_PUT, + }; + + int max_header_size; + int max_content_size; + int max_redirects; + int max_retries; + int timeout; + + String host; + int port; + String proxy_host; + int proxy_port; + String proxy_username; + String proxy_password; + String ssl_proxy_host; + int ssl_proxy_port; + String ssl_proxy_username; + String ssl_proxy_password; + String path; + bool ssl; + + int method; + String accept; + String agent; + bool force_digest; + bool is_post; + bool std_headers; + bool hasurlvar; + String contenttype; + String username; + String password; + String digest; + String request_headers; + String postdata; + + String protocol; + int status_code; + String reason_phrase; + + int start_time; + int retry_count; + int redirect_count; + + int chunk; + + IpAddrInfo addrinfo; + int bodylen; + bool gzip; + Zlib z; + + void Init(); + + void StartPhase(int s); + void Start(); + void Dns(); + void StartConnect(); + void ProcessSSLProxyResponse(); + void AfterConnect(); + void StartRequest(); + bool SendingData(); + bool ReadingHeader(); + void StartBody(); + bool ReadingBody(); + void ReadingChunkHeader(); + void Finish(); + + void CopyCookies(); + + void HttpError(const char *s); + void ContentOut(const void *ptr, dword size); + void Out(const void *ptr, dword size); + + String CalculateDigest(const String& authenticate) const; + +public: + Callback2 WhenContent; + + HttpRequest& MaxHeaderSize(int m) { max_header_size = m; return *this; } + HttpRequest& MaxContentSize(int m) { max_content_size = m; return *this; } + HttpRequest& MaxRedirect(int n) { max_redirects = n; return *this; } + HttpRequest& MaxRetries(int n) { max_retries = n; return *this; } + HttpRequest& RequestTimeout(int ms) { timeout = ms; return *this; } + HttpRequest& ChunkSize(int n) { chunk = n; return *this; } + + HttpRequest& Method(int m) { method = m; return *this; } + HttpRequest& GET() { return Method(METHOD_GET); } + HttpRequest& POST() { return Method(METHOD_POST); } + HttpRequest& HEAD() { return Method(METHOD_HEAD); } + HttpRequest& PUT() { return Method(METHOD_PUT); } + + HttpRequest& Host(const String& h) { host = h; return *this; } + HttpRequest& Port(int p) { port = p; return *this; } + HttpRequest& SSL(bool b = true) { ssl = b; return *this; } + HttpRequest& Path(const String& p) { path = p; return *this; } + HttpRequest& User(const String& u, const String& p) { username = u; password = p; return *this; } + HttpRequest& Digest() { force_digest = true; return *this; } + HttpRequest& Digest(const String& d) { digest = d; return *this; } + HttpRequest& Url(const char *url); + HttpRequest& UrlVar(const char *id, const String& data); + HttpRequest& operator()(const char *id, const String& data) { return UrlVar(id, data); } + HttpRequest& PostData(const String& pd) { postdata = pd; return *this; } + HttpRequest& PostUData(const String& pd) { return PostData(UrlEncode(pd)); } + HttpRequest& Post(const String& data) { POST(); return PostData(data); } + HttpRequest& Post(const char *id, const String& data); + + HttpRequest& Headers(const String& h) { request_headers = h; return *this; } + HttpRequest& ClearHeaders() { return Headers(Null); } + HttpRequest& AddHeaders(const String& h) { request_headers.Cat(h); return *this; } + HttpRequest& Header(const char *id, const String& data); + HttpRequest& Cookie(const String& cookie) { return Header("Cookie", cookie); } + + HttpRequest& StdHeaders(bool sh) { std_headers = sh; return *this; } + HttpRequest& NoStdHeaders() { return StdHeaders(false); } + HttpRequest& Accept(const String& a) { accept = a; return *this; } + HttpRequest& UserAgent(const String& a) { agent = a; return *this; } + HttpRequest& ContentType(const String& a) { contenttype = a; return *this; } + + HttpRequest& Proxy(const String& host, int port) { proxy_host = host; proxy_port = port; return *this; } + HttpRequest& Proxy(const char *p); + HttpRequest& ProxyAuth(const String& u, const String& p) { proxy_username = u; proxy_password = p; return *this; } + + HttpRequest& SSLProxy(const String& host, int port) { ssl_proxy_host = host; ssl_proxy_port = port; return *this; } + HttpRequest& SSLProxy(const char *p); + HttpRequest& SSLProxyAuth(const String& u, const String& p) { ssl_proxy_username = u; ssl_proxy_password = p; return *this; } + + bool IsSocketError() const { return TcpSocket::IsError(); } + bool IsHttpError() const { return !IsNull(error) ; } + bool IsError() const { return IsSocketError() || IsHttpError(); } + String GetErrorDesc() const { return IsSocketError() ? TcpSocket::GetErrorDesc() : error; } + void ClearError() { TcpSocket::ClearError(); error.Clear(); } + + String GetHeader(const char *s) { return header[s]; } + String operator[](const char *s) { return GetHeader(s); } + String GetRedirectUrl(); + int GetContentLength(); + int GetStatusCode() const { return status_code; } + String GetReasonPhrase() const { return reason_phrase; } + + String GetContent() const { return body; } + String operator~() const { return GetContent(); } + operator String() const { return GetContent(); } + void ClearContent() { body.Clear(); } + + enum Phase { + START, DNS, + SSLPROXYREQUEST, SSLPROXYRESPONSE, SSLHANDSHAKE, + REQUEST, HEADER, BODY, + CHUNK_HEADER, CHUNK_BODY, TRAILER, + FINISHED, FAILED + }; + + bool Do(); + int GetPhase() const { return phase; } + String GetPhaseName() const; + bool InProgress() const { return phase != FAILED && phase != FINISHED; } + bool IsFailure() const { return phase == FAILED; } + bool IsSuccess() const { return phase == FINISHED && status_code >= 200 && status_code < 300; } + + String Execute(); + + HttpRequest(); + HttpRequest(const char *url); + + static void Trace(bool b = true); +}; diff --git a/uppsrc/Core/WebUtil.cpp b/uppsrc/Core/InetUtil.cpp similarity index 100% rename from uppsrc/Core/WebUtil.cpp rename to uppsrc/Core/InetUtil.cpp diff --git a/uppsrc/Core/SSL/InitExit.cpp b/uppsrc/Core/SSL/InitExit.cpp index b38d3d476..3b1c1de65 100644 --- a/uppsrc/Core/SSL/InitExit.cpp +++ b/uppsrc/Core/SSL/InitExit.cpp @@ -71,6 +71,8 @@ EXITBLOCK ERR_free_strings(); } +#ifdef _MULTITHREADED + static thread__ bool sThreadInit; static thread__ void (*sPrevExit)(); @@ -89,4 +91,10 @@ void SslInitThread() sPrevExit = Thread::AtExit(sslExitThread); } +#else + +void SslInitThread() {} + +#endif + END_UPP_NAMESPACE diff --git a/uppsrc/Core/SSL/Socket.cpp b/uppsrc/Core/SSL/Socket.cpp index 88640aa83..da56b0e69 100644 --- a/uppsrc/Core/SSL/Socket.cpp +++ b/uppsrc/Core/SSL/Socket.cpp @@ -1,21 +1,24 @@ #include "SSL.h" -#define LLOG(x) DLOG(x) +#define LLOG(x) // DLOG(x) NAMESPACE_UPP struct TcpSocket::SSLImp : TcpSocket::SSL { - virtual bool Start(); - virtual bool Wait(dword flags); - virtual int Send(const void *buffer, int maxlen); - virtual int Recv(void *buffer, int maxlen); - virtual void Close(); + virtual bool Start(); + virtual bool Wait(dword flags); + virtual int Send(const void *buffer, int maxlen); + virtual int Recv(void *buffer, int maxlen); + virtual void Close(); + virtual dword Handshake(); TcpSocket& socket; SslContext context; ::SSL *ssl; SslCertificate cert; + int GetErrorCode(int res); + String GetErrorText(int code); void SetSSLError(const char *context); void SetSSLResError(const char *context, int res); bool IsAgain(int res) const; @@ -53,14 +56,14 @@ void TcpSocket::SSLImp::SetSSLError(const char *context) const char *TcpSocketErrorDesc(int code); -void TcpSocket::SSLImp::SetSSLResError(const char *context, int res) +int TcpSocket::SSLImp::GetErrorCode(int res) +{ + return SSL_get_error(ssl, res); +} + +String TcpSocket::SSLImp::GetErrorText(int code) { - int code = SSL_get_error(ssl, res); String out; - if(code == SSL_ERROR_SYSCALL) { - socket.SetSockError(context); - return; - } switch(code) { #define SSLERR(c) case c: out = #c; break; SSLERR(SSL_ERROR_NONE) @@ -76,7 +79,17 @@ void TcpSocket::SSLImp::SetSSLResError(const char *context, int res) #endif default: out = "unknown code"; break; } - socket.SetSockError(context, code, out); + return out; +} + +void TcpSocket::SSLImp::SetSSLResError(const char *context, int res) +{ + int code = GetErrorCode(res); + if(code == SSL_ERROR_SYSCALL) { + socket.SetSockError(context); + return; + } + socket.SetSockError(context, code, GetErrorText(code)); } bool TcpSocket::SSLImp::IsAgain(int res) const @@ -88,7 +101,7 @@ bool TcpSocket::SSLImp::IsAgain(int res) const res == SSL_ERROR_WANT_ACCEPT; } -bool TcpSocket::SSLImp::Start() // TIMEOUTS!!! +bool TcpSocket::SSLImp::Start() { LLOG("SSL Start"); @@ -96,16 +109,7 @@ bool TcpSocket::SSLImp::Start() // TIMEOUTS!!! SetSSLError("Start: SSL context."); return false; } - -/* - while(!socket.Wait(WAIT_WRITE)) { - DLOG("Waiting for connect"); - Sleep(1); - } - DLOG("Connected"); - SSL_CTX *context = SSL_CTX_new (SSLv3_client_method()); -*/ - +// context.VerifyPeer(); if(!(ssl = SSL_new(context))) { SetSSLError("Start: SSL_new"); return false; @@ -114,32 +118,43 @@ bool TcpSocket::SSLImp::Start() // TIMEOUTS!!! SetSSLError("Start: SSL_set_fd"); return false; } + return Handshake(); +} + +dword TcpSocket::SSLImp::Handshake() +{ int res; - if(socket.mode == ACCEPT) { -// SSL_set_accept_state(ssl); - int res = SSL_accept(ssl); - if(res <= 0 && !IsAgain(res)) { - SetSSLResError("Start: SSL_accept", res); - return false; - } - } - else { -// SSL_set_connect_state(ssl); - for(;;) { - res = SSL_connect(ssl); - if(res > 0) - break; - DDUMP(res); - DDUMP(IsAgain(res)); - if(res <= 0 && !IsAgain(res)) { - SetSSLResError("Start: SSL_connect", res); - return false; - } - Sleep(100); - } + if(socket.mode == ACCEPT) + res = SSL_accept(ssl); + else + if(socket.mode == CONNECT) + res = SSL_connect(ssl); + else + return 0; + if(res <= 0) { + int code = GetErrorCode(res); + if(code == SSL_ERROR_WANT_READ) + return WAIT_READ; + if(code == SSL_ERROR_WANT_WRITE) + return WAIT_WRITE; + SetSSLResError("SSL handshake", res); + return 0; } + socket.mode = SSL_CONNECTED; cert.Set(SSL_get_peer_certificate(ssl)); - return true; + SSLInfo& f = socket.sslinfo.Create(); + f.cipher = SSL_get_cipher(ssl); + if(!cert.IsEmpty()) { + f.cert_avail = true; + f.cert_subject = cert.GetSubjectName(); + f.cert_issuer = cert.GetIssuerName(); + f.cert_serial = cert.GetSerialNumber(); + f.cert_notbefore = cert.GetNotBefore(); + f.cert_notafter = cert.GetNotAfter(); + f.cert_version = cert.GetVersion(); + f.cert_verified = SSL_get_verify_result(ssl) == X509_V_OK; + } + return 0; } bool TcpSocket::SSLImp::Wait(dword flags) @@ -156,6 +171,9 @@ int TcpSocket::SSLImp::Send(const void *buffer, int maxlen) int res = SSL_write(ssl, (const char *)buffer, maxlen); if(res > 0) return res; + if(res == 0) + socket.is_eof = true; + else if(!IsAgain(res)) SetSSLResError("SSL_write", res); return 0; @@ -167,9 +185,10 @@ int TcpSocket::SSLImp::Recv(void *buffer, int maxlen) int res = SSL_read(ssl, (char *)buffer, maxlen); if(res > 0) return res; - - socket.is_eof = true; - if(res && !IsAgain(res)) + if(res == 0) + socket.is_eof = true; + else + if(!IsAgain(res)) SetSSLResError("SSL_read", res); return 0; } @@ -183,255 +202,4 @@ void TcpSocket::SSLImp::Close() ssl = NULL; } -#if 0 -class SSLSocketData : public TcpSocket::Data -{ -public: - SSLSocketData(SslContext& context); - virtual ~SSLSocketData(); - - bool OpenClientUnsecured(const char *host, int port, bool nodelay, dword *my_addr, - int timeout, bool is_blocking); - bool Secure(); - bool OpenAccept(SOCKET connection, bool nodelay, bool blocking); - - virtual int GetKind() const { return SOCKKIND_SSL; } - virtual bool Peek(int timeout_msec, bool write); - virtual int Read(void *buf, int amount); - virtual int Write(const void *buf, int amount); - virtual bool Accept(Socket& socket, dword *ipaddr, bool nodelay, int timeout_msec); - virtual bool Close(int timeout_msec); - virtual Value GetInfo(String info) const; - - void SetSSLError(const char *context); - void SetSSLResError(const char *context, int res); - -public: - SslContext& ssl_context; - SSL *ssl; - SslCertificate cert; -}; - -SSLSocketData::SSLSocketData(SslContext& ssl_context) -: ssl_context(ssl_context) -{ - SSLInit().AddThread(); - ssl = NULL; -} - -SSLSocketData::~SSLSocketData() -{ - Close(0); -} - -void SSLSocketData::SetSSLError(const char *context) -{ - if(sock) { - int code; - String text = SSLGetLastError(code); - SetSockError(context, code, text); - } -} - -void SSLSocketData::SetSSLResError(const char *context, int res) -{ - if(sock) { - int code = SSL_get_error(ssl, res); - String out; - switch(code) { - #define SSLERR(c) case c: out = #c; break; - SSLERR(SSL_ERROR_NONE) - SSLERR(SSL_ERROR_SSL) - SSLERR(SSL_ERROR_WANT_READ) - SSLERR(SSL_ERROR_WANT_WRITE) - SSLERR(SSL_ERROR_WANT_X509_LOOKUP) - SSLERR(SSL_ERROR_SYSCALL) - SSLERR(SSL_ERROR_ZERO_RETURN) - SSLERR(SSL_ERROR_WANT_CONNECT) - #ifdef PLATFORM_WIN32 - SSLERR(SSL_ERROR_WANT_ACCEPT) - #endif - default: out = "unknown code"; break; - } - SetSockError(context, code, out); - } -} - -bool SSLSocketData::Peek(int timeout_msec, bool write) -{ - if(ssl && !write && SSL_pending(ssl) > 0) - return true; - return Data::Peek(timeout_msec, write); -} - -int SSLSocketData::Read(void *buf, int amount) -{ - if(!ssl) - return Data::Read(buf, amount); - int res = SSL_read(ssl, (char *)buf, amount); - if(res == 0) { - is_eof = true; - if(SSL_get_shutdown(ssl) & SSL_RECEIVED_SHUTDOWN) - return 0; - } - if(res <= 0) - SetSSLResError("SSL_read", res); -#ifndef NOFAKEERROR - if(fake_error && res > 0) { - if((fake_error -= res) <= 0) { - fake_error = 0; - SetSockError("SSL_read", 0, "fake error"); - return -1; - } - else - RLOG("SSLSocketData::Read: fake error after " << fake_error); - } -#endif - return res; -} - -int SSLSocketData::Write(const void *buf, int amount) -{ - if(!ssl) - return Data::Write(buf, amount); - int res = SSL_write(ssl, (const char *)buf, amount); - if(res <= 0) - SetSSLResError("SSL_write", res); - return res; -} - -bool SSLSocketData::OpenClientUnsecured(const char *host, int port, bool nodelay, dword *my_addr, - int timeout, bool blocking) -{ - return Data::OpenClient(host, port, nodelay, my_addr, timeout, /*blocking*/true); -} - -bool SSLSocketData::Secure() -{ - if(!(ssl = SSL_new(ssl_context))) - { - SetSSLError("OpenClient / SSL_new"); - return false; - } - if(!SSL_set_fd(ssl, socket)) - { - SetSSLError("OpenClient / SSL_set_fd"); - return false; - } - SSL_set_connect_state(ssl); - int res = SSL_connect(ssl); - if(res <= 0) - { - SetSSLResError("OpenClient / SSL_connect", res); - return false; - } - cert.Set(SSL_get_peer_certificate(ssl)); - return true; -} - -bool SSLSocketData::OpenAccept(SOCKET conn, bool nodelay, bool blocking) -{ - Attach(conn, nodelay, blocking); - if(!(ssl = SSL_new(ssl_context))) - { - SetSSLError("Accept / SSL_new"); - return false; - } - if(!SSL_set_fd(ssl, socket)) - { - SetSSLError("Accept / SSL_set_fd"); - return false; - } - SSL_set_accept_state(ssl); - int res = SSL_accept(ssl); - if(res <= 0) - { - SetSSLResError("Accept / SSL_accept", res); - return false; - } - cert.Set(SSL_get_peer_certificate(ssl)); - return true; -} - -bool SSLSocketData::Accept(Socket& socket, dword *ipaddr, bool nodelay, int timeout_msec) -{ - SOCKET connection = AcceptRaw(ipaddr, timeout_msec); - if(connection == INVALID_SOCKET) - return false; - One data = new SSLSocketData(ssl_context); - if(!data->OpenAccept(connection, nodelay, is_blocking)) - return false; - socket.Attach(-data); - return true; -} - -bool SSLSocketData::Close(int timeout_msec) -{ - if(ssl) - SSL_shutdown(ssl); - bool res = Data::Close(timeout_msec); - if(ssl) { - SSL_free(ssl); - ssl = NULL; - } - return res; -} - -Value SSLSocketData::GetInfo(String info) const -{ - if(info == SSLInfoCipher()) return SSL_get_cipher(ssl); - if(info == SSLInfoCertAvail()) return cert.IsEmpty() ? 0 : 1; - if(info == SSLInfoCertVerified()) return SSL_get_verify_result(ssl) == X509_V_OK ? 1 : 0; - if(info == SSLInfoCertSubjectName()) return cert.IsEmpty() ? String::GetVoid() : cert.GetSubjectName(); - if(info == SSLInfoCertIssuerName()) return cert.IsEmpty() ? String::GetVoid() : cert.GetIssuerName(); - if(info == SSLInfoCertNotBefore()) return cert.IsEmpty() ? Date(Null) : cert.GetNotBefore(); - if(info == SSLInfoCertNotAfter()) return cert.IsEmpty() ? Date(Null) : cert.GetNotAfter(); - if(info == SSLInfoCertVersion()) return cert.IsEmpty() ? int(Null) : cert.GetVersion(); - if(info == SSLInfoCertSerialNumber()) return cert.IsEmpty() ? String::GetVoid() : cert.GetSerialNumber(); - - return Data::GetInfo(info); -} - -bool SSLServerSocket(Socket& socket, SslContext& ssl_context, int port, bool nodelay, int listen_count, bool blocking) -{ - One data = new SSLSocketData(ssl_context); - if(!data->OpenServer(port, nodelay, listen_count, blocking)) - return false; - socket.Attach(-data); - return true; -} - -bool SSLClientSocket(Socket& socket, SslContext& ssl_context, const char *host, int port, bool nodelay, - dword *my_addr, int timeout, bool blocking) -{ - One data = new SSLSocketData(ssl_context); - if(!data->OpenClient(host, port, nodelay, my_addr, timeout, blocking)) - return false; - if(!data->Secure()) - return false; - socket.Attach(-data); - return true; -} - -bool SSLClientSocketUnsecured(Socket& socket, SslContext& ssl_context, const char *host, - int port, bool nodelay, dword *my_addr, int timeout, - bool is_blocking) -{ - One data = new SSLSocketData(ssl_context); - if(data->OpenClientUnsecured(host, port, nodelay, my_addr, timeout, is_blocking)) { - socket.Attach(-data); - return true; - } - return false; -} - -bool SSLSecureSocket(Socket& socket) -{ - SSLSocketData *sd = dynamic_cast(~socket.data); - if(!sd) - return false; - return sd->Secure(); -} -#endif - END_UPP_NAMESPACE diff --git a/uppsrc/Core/Socket.cpp b/uppsrc/Core/Socket.cpp index 5c7de40f0..e8e310265 100644 --- a/uppsrc/Core/Socket.cpp +++ b/uppsrc/Core/Socket.cpp @@ -1,797 +1,821 @@ -#include "Core.h" - -#ifdef PLATFORM_WIN32 -#include - #ifdef COMPILER_MSC - #include - #endif -#include -#endif - -#ifdef PLATFORM_POSIX -#include -#endif - -NAMESPACE_UPP - -#ifdef PLATFORM_WIN32 -#pragma comment(lib, "ws2_32.lib") -#endif - -#define LLOG(x) // DLOG("TCP " << x) - -IpAddrInfo::Entry IpAddrInfo::pool[COUNT]; - -RawMutex IpAddrInfoPoolMutex; - -void IpAddrInfo::EnterPool() -{ - IpAddrInfoPoolMutex.Enter(); -} - -void IpAddrInfo::LeavePool() -{ - IpAddrInfoPoolMutex.Leave(); -} - -int sGetAddrInfo(const char *host, const char *port, addrinfo **result) -{ - addrinfo hints; - memset(&hints, 0, sizeof(addrinfo)); - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - hints.ai_protocol = IPPROTO_TCP; - - return getaddrinfo(host, port, &hints, result); -} - -rawthread_t rawthread__ IpAddrInfo::Thread(void *ptr) -{ - Entry *entry = (Entry *)ptr; - EnterPool(); - if(entry->status == WORKING) { - char host[1025]; - char port[257]; - strcpy(host, entry->host); - strcpy(port, entry->port); - LeavePool(); - addrinfo *result; - if(sGetAddrInfo(host, port, &result) == 0 && result) { - EnterPool(); - if(entry->status == WORKING) { - entry->addr = result; - entry->status = RESOLVED; - } - else { - freeaddrinfo(result); - entry->status = EMPTY; - } - } - else { - EnterPool(); - if(entry->status == CANCELED) - entry->status = EMPTY; - else - entry->status = FAILED; - } - } - LeavePool(); - return 0; -} - -bool IpAddrInfo::Execute(const String& host, int port) -{ - Clear(); - entry = exe; - addrinfo *result; - entry->addr = sGetAddrInfo(~host, ~AsString(port), &result) == 0 ? result : NULL; - return entry->addr; -} - -void IpAddrInfo::Start() -{ - if(entry) - return; - EnterPool(); - for(int i = 0; i < COUNT; i++) { - Entry *e = pool + i; - if(e->status == EMPTY) { - entry = e; - e->addr = NULL; - if(host.GetCount() > 1024 || port.GetCount() > 256) - e->status = FAILED; - else { - e->status = WORKING; - e->host = host; - e->port = port; - StartRawThread(&IpAddrInfo::Thread, e); - } - break; - } - } - LeavePool(); -} - -void IpAddrInfo::Start(const String& host_, int port_) -{ - Clear(); - port = AsString(port_); - host = host_; - Start(); -} - -bool IpAddrInfo::InProgress() -{ - if(!entry) { - Start(); - return true; - } - EnterPool(); - int s = entry->status; - LeavePool(); - return s == WORKING; -} - -addrinfo *IpAddrInfo::GetResult() -{ - EnterPool(); - addrinfo *ai = entry ? entry->addr : NULL; - LeavePool(); - return ai; -} - -void IpAddrInfo::Clear() -{ - EnterPool(); - if(entry) { - if(entry->status == RESOLVED && entry->addr) - freeaddrinfo(entry->addr); - if(entry->status == WORKING) - entry->status = CANCELED; - else - entry->status = EMPTY; - entry = NULL; - } - LeavePool(); -} - -IpAddrInfo::IpAddrInfo() -{ - TcpSocket::Init(); - entry = NULL; -} - -#ifdef PLATFORM_POSIX - -#define SOCKERR(x) x - -const char *TcpSocketErrorDesc(int code) -{ - return strerror(code); -} - -int TcpSocket::GetErrorCode() -{ - return errno; -} - -#else - -#define SOCKERR(x) WSA##x - -const char *TcpSocketErrorDesc(int code) -{ - static Tuple2 err[] = { - { WSAEINTR, "Interrupted function call." }, - { WSAEACCES, "Permission denied." }, - { WSAEFAULT, "Bad address." }, - { WSAEINVAL, "Invalid argument." }, - { WSAEMFILE, "Too many open files." }, - { WSAEWOULDBLOCK, "Resource temporarily unavailable." }, - { WSAEINPROGRESS, "Operation now in progress." }, - { WSAEALREADY, "Operation already in progress." }, - { WSAENOTSOCK, "TcpSocket operation on nonsocket." }, - { WSAEDESTADDRREQ, "Destination address required." }, - { WSAEMSGSIZE, "Message too long." }, - { WSAEPROTOTYPE, "Protocol wrong type for socket." }, - { WSAENOPROTOOPT, "Bad protocol option." }, - { WSAEPROTONOSUPPORT, "Protocol not supported." }, - { WSAESOCKTNOSUPPORT, "TcpSocket type not supported." }, - { WSAEOPNOTSUPP, "Operation not supported." }, - { WSAEPFNOSUPPORT, "Protocol family not supported." }, - { WSAEAFNOSUPPORT, "Address family not supported by protocol family." }, - { WSAEADDRINUSE, "Address already in use." }, - { WSAEADDRNOTAVAIL, "Cannot assign requested address." }, - { WSAENETDOWN, "Network is down." }, - { WSAENETUNREACH, "Network is unreachable." }, - { WSAENETRESET, "Network dropped connection on reset." }, - { WSAECONNABORTED, "Software caused connection abort." }, - { WSAECONNRESET, "Connection reset by peer." }, - { WSAENOBUFS, "No buffer space available." }, - { WSAEISCONN, "TcpSocket is already connected." }, - { WSAENOTCONN, "TcpSocket is not connected." }, - { WSAESHUTDOWN, "Cannot send after socket shutdown." }, - { WSAETIMEDOUT, "Connection timed out." }, - { WSAECONNREFUSED, "Connection refused." }, - { WSAEHOSTDOWN, "Host is down." }, - { WSAEHOSTUNREACH, "No route to host." }, - { WSAEPROCLIM, "Too many processes." }, - { WSASYSNOTREADY, "Network subsystem is unavailable." }, - { WSAVERNOTSUPPORTED, "Winsock.dll version out of range." }, - { WSANOTINITIALISED, "Successful WSAStartup not yet performed." }, - { WSAEDISCON, "Graceful shutdown in progress." }, - { WSATYPE_NOT_FOUND, "Class type not found." }, - { WSAHOST_NOT_FOUND, "Host not found." }, - { WSATRY_AGAIN, "Nonauthoritative host not found." }, - { WSANO_RECOVERY, "This is a nonrecoverable error." }, - { WSANO_DATA, "Valid name, no data record of requested type." }, - { WSASYSCALLFAILURE, "System call failure." }, - }; - const Tuple2 *x = FindTuple(err, __countof(err), code); - return x ? x->b : "Unknown error code."; -} - -int TcpSocket::GetErrorCode() -{ - return WSAGetLastError(); -} - -#endif - -void TcpSocketInit() -{ -#if defined(PLATFORM_WIN32) - ONCELOCK { - WSADATA wsadata; - WSAStartup(MAKEWORD(2, 2), &wsadata); - } -#endif -} - -void TcpSocket::Init() -{ - TcpSocketInit(); -} - -void TcpSocket::Reset() -{ - is_eof = false; - socket = INVALID_SOCKET; - ipv6 = false; - ptr = end = buffer; - is_error = false; - is_abort = false; - mode = NONE; - ssl.Clear(); -} - -TcpSocket::TcpSocket() -{ - ClearError(); - Reset(); - timeout = Null; - waitstep = 20; -} - -bool TcpSocket::Open(int family, int type, int protocol) -{ - Init(); - Close(); - ClearError(); - if((socket = ::socket(family, type, protocol)) == INVALID_SOCKET) - return false; - LLOG("TcpSocket::Data::Open() -> " << (int)socket); -#ifdef PLATFORM_WIN32 - u_long arg = 1; - if(ioctlsocket(socket, FIONBIO, &arg)) - SetSockError("ioctlsocket(FIO[N]BIO)"); -#else - if(fcntl(socket, F_SETFL, (fcntl(socket, F_GETFL, 0) | O_NONBLOCK))) - SetSockError("fcntl(O_[NON]BLOCK)"); -#endif - return true; -} - -bool TcpSocket::Listen(int port, int listen_count, bool ipv6_, bool reuse) -{ - Close(); - Init(); - Reset(); - - ipv6 = ipv6_; - if(!Open(ipv6 ? AF_INET6 : AF_INET, SOCK_STREAM, 0)) - return false; - sockaddr_in sin; -#ifdef PLATFORM_WIN32 - SOCKADDR_IN6 sin6; - if(ipv6 && IsWinVista()) -#else - sockaddr_in6 sin6; - if(ipv6) -#endif - { - Zero(sin6); - sin.sin_family = AF_INET6; - sin.sin_port = htons(port); - sin.sin_addr.s_addr = htonl(INADDR_ANY); - } - else { - Zero(sin); - sin.sin_family = AF_INET; - sin.sin_port = htons(port); - sin.sin_addr.s_addr = htonl(INADDR_ANY); - } - if(reuse) { - int optval = 1; - setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, (const char *)&optval, sizeof(optval)); - } - if(bind(socket, ipv6 ? (const sockaddr *)&sin6 : (const sockaddr *)&sin, - ipv6 ? sizeof(sin6) : sizeof(sin))) { - SetSockError(Format("bind(port=%d)", port)); - return false; - } - if(listen(socket, listen_count)) { - SetSockError(Format("listen(port=%d, count=%d)", port, listen_count)); - return false; - } - return true; -} - -bool TcpSocket::Accept(TcpSocket& ls) -{ - Close(); - Init(); - Reset(); - - if(timeout && !ls.WaitRead()) - return false; - if(!Open(ls.ipv6 ? AF_INET6 : AF_INET, SOCK_STREAM, 0)) - return false; - socket = accept(ls.GetSOCKET(), NULL, NULL); - if(socket == INVALID_SOCKET) { - SetSockError("accept"); - return false; - } - mode = ACCEPT; - return true; -} - -String TcpSocket::GetPeerAddr() const -{ - if(!IsOpen()) - return Null; - sockaddr_in addr; - socklen_t l = sizeof(addr); - if(getpeername(socket, (sockaddr *)&addr, &l) != 0) - return Null; - if(l > sizeof(addr)) - return Null; -#ifdef PLATFORM_WIN32 - return inet_ntoa(addr.sin_addr); -#else - char h[200]; - return inet_ntop(AF_INET, &addr.sin_addr, h, 200); -#endif -} - -void TcpSocket::NoDelay() -{ - ASSERT(IsOpen()); - int __true = 1; - LLOG("NoDelay(" << (int)socket << ")"); - if(setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (const char *)&__true, sizeof(__true))) - SetSockError("setsockopt(TCP_NODELAY)"); -} - -void TcpSocket::Linger(int msecs) -{ - ASSERT(IsOpen()); - linger ls; - ls.l_onoff = !IsNull(msecs) ? 1 : 0; - ls.l_linger = !IsNull(msecs) ? (msecs + 999) / 1000 : 0; - if(setsockopt(socket, SOL_SOCKET, SO_LINGER, - reinterpret_cast(&ls), sizeof(ls))) - SetSockError("setsockopt(SO_LINGER)"); -} - -void TcpSocket::Attach(SOCKET s) -{ - Close(); - socket = s; -} - -bool TcpSocket::RawConnect(addrinfo *rp) -{ - if(!rp) { - SetSockError("connect", -1, "not found"); - return false; - } - for(;;) { - if(rp && Open(rp->ai_family, rp->ai_socktype, rp->ai_protocol)) { - if(connect(socket, rp->ai_addr, rp->ai_addrlen) == 0 || - GetErrorCode() == SOCKERR(EINPROGRESS) || GetErrorCode() == SOCKERR(EWOULDBLOCK)) - break; - Close(); - } - rp = rp->ai_next; - if(!rp) { - SetSockError("connect", -1, "failed"); - return false; - } - } - mode = CONNECT; - return true; -} - - -bool TcpSocket::Connect(IpAddrInfo& info) -{ - LLOG("TCP Connect addrinfo"); - Init(); - Reset(); - addrinfo *result = info.GetResult(); - return result && RawConnect(result); -} - -bool TcpSocket::Connect(const char *host, int port) -{ - LLOG("TCP Connect(" << host << ':' << port << ')'); - - Init(); - Reset(); - IpAddrInfo info; - if(!info.Execute(host, port)) { - SetSockError(Format("getaddrinfo(%s) failed", host)); - return false; - } - return Connect(info); -} - -void TcpSocket::RawClose() -{ - LLOG("TCP close " << (int)socket); - if(socket != INVALID_SOCKET) { - int res; -#if defined(PLATFORM_WIN32) - res = closesocket(socket); -#elif defined(PLATFORM_POSIX) - res = close(socket); -#else - #error Unsupported platform -#endif - if(res && !IsError()) - SetSockError("close"); - socket = INVALID_SOCKET; - } -} - -void TcpSocket::Close() -{ - if(ssl) - ssl->Close(); - else - RawClose(); - ssl.Clear(); -} - -bool TcpSocket::WouldBlock() -{ - int c = GetErrorCode(); -#ifdef PLATFORM_POSIX - return c == SOCKERR(EWOULDBLOCK) || c == SOCKERR(EAGAIN); -#endif -#ifdef PLATFORM_WIN32 - return c == SOCKERR(EWOULDBLOCK) || c == SOCKERR(ENOTCONN); -#endif -} - -int TcpSocket::RawSend(const void *buf, int amount) -{ - int res = send(socket, (const char *)buf, amount, 0); - if(res < 0 && WouldBlock()) - res = 0; - else - if(res == 0 || res < 0) - SetSockError("send"); - return res; -} - -int TcpSocket::Send(const void *buf, int amount) -{ - return ssl ? ssl->Send(buf, amount) : RawSend(buf, amount); -} - -void TcpSocket::Shutdown() -{ - ASSERT(IsOpen()); - if(shutdown(socket, SD_SEND)) - SetSockError("shutdown(SD_SEND)"); -} - -String TcpSocket::GetHostName() -{ - Init(); - char buffer[256]; - gethostname(buffer, __countof(buffer)); - return buffer; -} - -bool TcpSocket::RawWait(dword flags) -{ - LLOG("Wait(" << timeout << ", " << flags << ")"); - if((flags & WAIT_READ) && ptr != end) - return true; - int end_time = msecs() + timeout; - if(socket == INVALID_SOCKET) - return false; - for(;;) { - if(IsError() || IsAbort()) - return false; - int to = end_time - msecs(); - if(WhenWait) - to = waitstep; - timeval *tvalp = NULL; - timeval tval; - if(!IsNull(timeout) || WhenWait) { - to = max(to, 0); - tval.tv_sec = to / 1000; - tval.tv_usec = 1000 * (to % 1000); - tvalp = &tval; - } - fd_set fdset[1]; - FD_ZERO(fdset); - FD_SET(socket, fdset); - int avail = select((int)socket + 1, - flags & WAIT_READ ? fdset : NULL, - flags & WAIT_WRITE ? fdset : NULL, - flags & WAIT_EXCEPTION ? fdset : NULL, tvalp); - LLOG("Wait select avail: " << avail); - if(avail < 0) { - SetSockError("wait"); - return false; - } - if(avail > 0) - return true; - if(to <= 0 && timeout) { - return false; - } - WhenWait(); - if(timeout == 0) - return false; - } -} - -bool TcpSocket::Wait(dword flags) -{ - return ssl ? ssl->Wait(flags) : RawWait(flags); -} - -int TcpSocket::Put(const char *s, int length) -{ - LLOG("Put " << socket << ": " << length); - ASSERT(IsOpen()); - if(length < 0 && s) - length = (int)strlen(s); - if(!s || length <= 0 || IsError() || IsAbort()) - return 0; - done = 0; - bool peek = false; - while(done < length) { - if(peek && !WaitWrite()) - return done; - peek = false; - int count = Send(s + done, length - done); - if(IsError() || timeout == 0 && count == 0 && peek) - return done; - if(count > 0) - done += count; - else - peek = true; - } - LLOG("//Put() -> " << done); - return done; -} - -int TcpSocket::RawRecv(void *buf, int amount) -{ - int res = recv(socket, (char *)buf, amount, 0); - if(res == 0) - is_eof = true; - else - if(res < 0 && WouldBlock()) - res = 0; - else - if(res < 0) - SetSockError("recv"); - LLOG("recv(" << socket << "): " << res << " bytes: " - << AsCString((char *)buf, (char *)buf + min(res, 16)) - << (res ? "" : IsEof() ? ", EOF" : ", WOULDBLOCK")); - return res; -} - -int TcpSocket::Recv(void *buffer, int maxlen) -{ - return ssl ? ssl->Recv(buffer, maxlen) : RawRecv(buffer, maxlen); -} - -void TcpSocket::ReadBuffer() -{ - ptr = end = buffer; - if(WaitRead()) - end = buffer + Recv(buffer, BUFFERSIZE); -} - -int TcpSocket::Get_() -{ - if(!IsOpen() || IsError() || IsEof() || IsAbort()) - return -1; - ReadBuffer(); - return ptr < end ? *ptr++ : -1; -} - -int TcpSocket::Peek_() -{ - if(!IsOpen() || IsError() || IsEof() || IsAbort()) - return -1; - ReadBuffer(); - return ptr < end ? *ptr : -1; -} - -int TcpSocket::Get(void *buffer, int count) -{ - LLOG("Get " << count); - - if(!IsOpen() || IsError() || IsEof() || IsAbort()) - return 0; - - String out; - int l = end - ptr; - done = 0; - if(l > 0) - if(l < count) { - memcpy(buffer, ptr, l); - done += l; - ptr = end; - } - else { - memcpy(buffer, ptr, count); - ptr += count; - return count; - } - while(done < count && !IsError() && !IsEof()) { - if(!WaitRead()) - break; - int part = Recv((char *)buffer + done, count - done); - if(part > 0) - done += part; - if(timeout == 0) - break; - } - return done; -} - -String TcpSocket::Get(int count) -{ - if(count == 0) - return Null; - StringBuffer out(count); - int done = Get(out, count); - if(!done && IsEof()) - return String::GetVoid(); - out.SetLength(done); - return out; -} - -String TcpSocket::GetLine(int maxlen) -{ - String ln; - for(;;) { - int c = Peek(); - if(c < 0) - return String::GetVoid(); - Get(); - if(c == '\n') - return ln; - if(c != '\r') - ln.Cat(c); - } -} - -void TcpSocket::SetSockError(const char *context, int code, const char *errdesc) -{ - errorcode = code; - errordesc.Clear(); - if(socket != INVALID_SOCKET) - errordesc << "socket(" << (int)socket << ") / "; - errordesc << context << ": " << errdesc; - is_error = true; - LLOG("TCP ERROR " << errordesc); -} - -void TcpSocket::SetSockError(const char *context, const char *errdesc) -{ - SetSockError(context, GetErrorCode(), errdesc); -} - -void TcpSocket::SetSockError(const char *context) -{ - SetSockError(context, TcpSocketErrorDesc(GetErrorCode())); -} - -TcpSocket::SSL *(*TcpSocket::CreateSSL)(TcpSocket& socket); - -bool TcpSocket::StartSSL() -{ - ASSERT(IsOpen()); - if(!CreateSSL) { - SetSockError("StartSSL", -1, "Missing SSL support (Core/SSL)"); - return false; - } - if(!IsOpen()) { - SetSockError("StartSSL", -1, "Socket is not open"); - return false; - } - if(!IsOpen()) { - SetSockError("StartSSL", -1, "Socket is not connected"); - return false; - } - ssl = (*CreateSSL)(*this); - if(!ssl->Start()) { - ssl.Clear(); - return false; - } - return true; -} - -int SocketWaitEvent::Wait(int timeout) -{ - FD_ZERO(read); - FD_ZERO(write); - FD_ZERO(exception); - int maxindex = -1; - for(int i = 0; i < socket.GetCount(); i++) { - const Tuple2& s = socket[i]; - if(s.a >= 0) { - const Tuple2& s = socket[i]; - if(s.b & WAIT_READ) - FD_SET(s.a, read); - if(s.b & WAIT_WRITE) - FD_SET(s.a, write); - if(s.b & WAIT_EXCEPTION) - FD_SET(s.a, exception); - maxindex = max(s.a, maxindex); - } - } - timeval *tvalp = NULL; - timeval tval; - if(!IsNull(timeout)) { - tval.tv_sec = timeout / 1000; - tval.tv_usec = 1000 * (timeout % 1000); - tvalp = &tval; - } - return select(maxindex + 1, read, write, exception, tvalp); -} - -dword SocketWaitEvent::Get(int i) const -{ - int s = socket[i].a; - if(s < 0) - return 0; - dword events = 0; - if(FD_ISSET(s, read)) - events |= WAIT_READ; - if(FD_ISSET(s, write)) - events |= WAIT_WRITE; - if(FD_ISSET(s, exception)) - events |= WAIT_EXCEPTION; - return events; -} - -SocketWaitEvent::SocketWaitEvent() -{ - FD_ZERO(read); - FD_ZERO(write); - FD_ZERO(exception); -} - -END_UPP_NAMESPACE +#include "Core.h" + +#ifdef PLATFORM_WIN32 +#include + #ifdef COMPILER_MSC + #include + #endif +#include +#endif + +#ifdef PLATFORM_POSIX +#include +#endif + +NAMESPACE_UPP + +#ifdef PLATFORM_WIN32 +#pragma comment(lib, "ws2_32.lib") +#endif + +#define LLOG(x) // DLOG("TCP " << x) + +IpAddrInfo::Entry IpAddrInfo::pool[COUNT]; + +RawMutex IpAddrInfoPoolMutex; + +void IpAddrInfo::EnterPool() +{ + IpAddrInfoPoolMutex.Enter(); +} + +void IpAddrInfo::LeavePool() +{ + IpAddrInfoPoolMutex.Leave(); +} + +int sGetAddrInfo(const char *host, const char *port, addrinfo **result) +{ + addrinfo hints; + memset(&hints, 0, sizeof(addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; + + return getaddrinfo(host, port, &hints, result); +} + +rawthread_t rawthread__ IpAddrInfo::Thread(void *ptr) +{ + Entry *entry = (Entry *)ptr; + EnterPool(); + if(entry->status == WORKING) { + char host[1025]; + char port[257]; + strcpy(host, entry->host); + strcpy(port, entry->port); + LeavePool(); + addrinfo *result; + if(sGetAddrInfo(host, port, &result) == 0 && result) { + EnterPool(); + if(entry->status == WORKING) { + entry->addr = result; + entry->status = RESOLVED; + } + else { + freeaddrinfo(result); + entry->status = EMPTY; + } + } + else { + EnterPool(); + if(entry->status == CANCELED) + entry->status = EMPTY; + else + entry->status = FAILED; + } + } + LeavePool(); + return 0; +} + +bool IpAddrInfo::Execute(const String& host, int port) +{ + Clear(); + entry = exe; + addrinfo *result; + entry->addr = sGetAddrInfo(~host, ~AsString(port), &result) == 0 ? result : NULL; + return entry->addr; +} + +void IpAddrInfo::Start() +{ + if(entry) + return; + EnterPool(); + for(int i = 0; i < COUNT; i++) { + Entry *e = pool + i; + if(e->status == EMPTY) { + entry = e; + e->addr = NULL; + if(host.GetCount() > 1024 || port.GetCount() > 256) + e->status = FAILED; + else { + e->status = WORKING; + e->host = host; + e->port = port; + StartRawThread(&IpAddrInfo::Thread, e); + } + break; + } + } + LeavePool(); +} + +void IpAddrInfo::Start(const String& host_, int port_) +{ + Clear(); + port = AsString(port_); + host = host_; + Start(); +} + +bool IpAddrInfo::InProgress() +{ + if(!entry) { + Start(); + return true; + } + EnterPool(); + int s = entry->status; + LeavePool(); + return s == WORKING; +} + +addrinfo *IpAddrInfo::GetResult() +{ + EnterPool(); + addrinfo *ai = entry ? entry->addr : NULL; + LeavePool(); + return ai; +} + +void IpAddrInfo::Clear() +{ + EnterPool(); + if(entry) { + if(entry->status == RESOLVED && entry->addr) + freeaddrinfo(entry->addr); + if(entry->status == WORKING) + entry->status = CANCELED; + else + entry->status = EMPTY; + entry = NULL; + } + LeavePool(); +} + +IpAddrInfo::IpAddrInfo() +{ + TcpSocket::Init(); + entry = NULL; +} + +#ifdef PLATFORM_POSIX + +#define SOCKERR(x) x + +const char *TcpSocketErrorDesc(int code) +{ + return strerror(code); +} + +int TcpSocket::GetErrorCode() +{ + return errno; +} + +#else + +#define SOCKERR(x) WSA##x + +const char *TcpSocketErrorDesc(int code) +{ + static Tuple2 err[] = { + { WSAEINTR, "Interrupted function call." }, + { WSAEACCES, "Permission denied." }, + { WSAEFAULT, "Bad address." }, + { WSAEINVAL, "Invalid argument." }, + { WSAEMFILE, "Too many open files." }, + { WSAEWOULDBLOCK, "Resource temporarily unavailable." }, + { WSAEINPROGRESS, "Operation now in progress." }, + { WSAEALREADY, "Operation already in progress." }, + { WSAENOTSOCK, "TcpSocket operation on nonsocket." }, + { WSAEDESTADDRREQ, "Destination address required." }, + { WSAEMSGSIZE, "Message too long." }, + { WSAEPROTOTYPE, "Protocol wrong type for socket." }, + { WSAENOPROTOOPT, "Bad protocol option." }, + { WSAEPROTONOSUPPORT, "Protocol not supported." }, + { WSAESOCKTNOSUPPORT, "TcpSocket type not supported." }, + { WSAEOPNOTSUPP, "Operation not supported." }, + { WSAEPFNOSUPPORT, "Protocol family not supported." }, + { WSAEAFNOSUPPORT, "Address family not supported by protocol family." }, + { WSAEADDRINUSE, "Address already in use." }, + { WSAEADDRNOTAVAIL, "Cannot assign requested address." }, + { WSAENETDOWN, "Network is down." }, + { WSAENETUNREACH, "Network is unreachable." }, + { WSAENETRESET, "Network dropped connection on reset." }, + { WSAECONNABORTED, "Software caused connection abort." }, + { WSAECONNRESET, "Connection reset by peer." }, + { WSAENOBUFS, "No buffer space available." }, + { WSAEISCONN, "TcpSocket is already connected." }, + { WSAENOTCONN, "TcpSocket is not connected." }, + { WSAESHUTDOWN, "Cannot send after socket shutdown." }, + { WSAETIMEDOUT, "Connection timed out." }, + { WSAECONNREFUSED, "Connection refused." }, + { WSAEHOSTDOWN, "Host is down." }, + { WSAEHOSTUNREACH, "No route to host." }, + { WSAEPROCLIM, "Too many processes." }, + { WSASYSNOTREADY, "Network subsystem is unavailable." }, + { WSAVERNOTSUPPORTED, "Winsock.dll version out of range." }, + { WSANOTINITIALISED, "Successful WSAStartup not yet performed." }, + { WSAEDISCON, "Graceful shutdown in progress." }, + { WSATYPE_NOT_FOUND, "Class type not found." }, + { WSAHOST_NOT_FOUND, "Host not found." }, + { WSATRY_AGAIN, "Nonauthoritative host not found." }, + { WSANO_RECOVERY, "This is a nonrecoverable error." }, + { WSANO_DATA, "Valid name, no data record of requested type." }, + { WSASYSCALLFAILURE, "System call failure." }, + }; + const Tuple2 *x = FindTuple(err, __countof(err), code); + return x ? x->b : "Unknown error code."; +} + +int TcpSocket::GetErrorCode() +{ + return WSAGetLastError(); +} + +#endif + +void TcpSocketInit() +{ +#if defined(PLATFORM_WIN32) + ONCELOCK { + WSADATA wsadata; + WSAStartup(MAKEWORD(2, 2), &wsadata); + } +#endif +} + +void TcpSocket::Init() +{ + TcpSocketInit(); +} + +void TcpSocket::Reset() +{ + is_eof = false; + socket = INVALID_SOCKET; + ipv6 = false; + ptr = end = buffer; + is_error = false; + is_abort = false; + mode = NONE; + ssl.Clear(); + sslinfo.Clear(); +} + +TcpSocket::TcpSocket() +{ + ClearError(); + Reset(); + timeout = Null; + waitstep = 20; +} + +bool TcpSocket::Open(int family, int type, int protocol) +{ + Init(); + Close(); + ClearError(); + if((socket = ::socket(family, type, protocol)) == INVALID_SOCKET) + return false; + LLOG("TcpSocket::Data::Open() -> " << (int)socket); +#ifdef PLATFORM_WIN32 + u_long arg = 1; + if(ioctlsocket(socket, FIONBIO, &arg)) + SetSockError("ioctlsocket(FIO[N]BIO)"); +#else + if(fcntl(socket, F_SETFL, (fcntl(socket, F_GETFL, 0) | O_NONBLOCK))) + SetSockError("fcntl(O_[NON]BLOCK)"); +#endif + return true; +} + +bool TcpSocket::Listen(int port, int listen_count, bool ipv6_, bool reuse) +{ + Close(); + Init(); + Reset(); + + ipv6 = ipv6_; + if(!Open(ipv6 ? AF_INET6 : AF_INET, SOCK_STREAM, 0)) + return false; + sockaddr_in sin; +#ifdef PLATFORM_WIN32 + SOCKADDR_IN6 sin6; + if(ipv6 && IsWinVista()) +#else + sockaddr_in6 sin6; + if(ipv6) +#endif + { + Zero(sin6); + sin.sin_family = AF_INET6; + sin.sin_port = htons(port); + sin.sin_addr.s_addr = htonl(INADDR_ANY); + } + else { + Zero(sin); + sin.sin_family = AF_INET; + sin.sin_port = htons(port); + sin.sin_addr.s_addr = htonl(INADDR_ANY); + } + if(reuse) { + int optval = 1; + setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, (const char *)&optval, sizeof(optval)); + } + if(bind(socket, ipv6 ? (const sockaddr *)&sin6 : (const sockaddr *)&sin, + ipv6 ? sizeof(sin6) : sizeof(sin))) { + SetSockError(Format("bind(port=%d)", port)); + return false; + } + if(listen(socket, listen_count)) { + SetSockError(Format("listen(port=%d, count=%d)", port, listen_count)); + return false; + } + return true; +} + +bool TcpSocket::Accept(TcpSocket& ls) +{ + Close(); + Init(); + Reset(); + + if(timeout && !ls.WaitRead()) + return false; + if(!Open(ls.ipv6 ? AF_INET6 : AF_INET, SOCK_STREAM, 0)) + return false; + socket = accept(ls.GetSOCKET(), NULL, NULL); + if(socket == INVALID_SOCKET) { + SetSockError("accept"); + return false; + } + mode = ACCEPT; + return true; +} + +String TcpSocket::GetPeerAddr() const +{ + if(!IsOpen()) + return Null; + sockaddr_in addr; + socklen_t l = sizeof(addr); + if(getpeername(socket, (sockaddr *)&addr, &l) != 0) + return Null; + if(l > sizeof(addr)) + return Null; +#ifdef PLATFORM_WIN32 + return inet_ntoa(addr.sin_addr); +#else + char h[200]; + return inet_ntop(AF_INET, &addr.sin_addr, h, 200); +#endif +} + +void TcpSocket::NoDelay() +{ + ASSERT(IsOpen()); + int __true = 1; + LLOG("NoDelay(" << (int)socket << ")"); + if(setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (const char *)&__true, sizeof(__true))) + SetSockError("setsockopt(TCP_NODELAY)"); +} + +void TcpSocket::Linger(int msecs) +{ + ASSERT(IsOpen()); + linger ls; + ls.l_onoff = !IsNull(msecs) ? 1 : 0; + ls.l_linger = !IsNull(msecs) ? (msecs + 999) / 1000 : 0; + if(setsockopt(socket, SOL_SOCKET, SO_LINGER, + reinterpret_cast(&ls), sizeof(ls))) + SetSockError("setsockopt(SO_LINGER)"); +} + +void TcpSocket::Attach(SOCKET s) +{ + Close(); + socket = s; +} + +bool TcpSocket::RawConnect(addrinfo *rp) +{ + if(!rp) { + SetSockError("connect", -1, "not found"); + return false; + } + for(;;) { + if(rp && Open(rp->ai_family, rp->ai_socktype, rp->ai_protocol)) { + if(connect(socket, rp->ai_addr, rp->ai_addrlen) == 0 || + GetErrorCode() == SOCKERR(EINPROGRESS) || GetErrorCode() == SOCKERR(EWOULDBLOCK)) + break; + Close(); + } + rp = rp->ai_next; + if(!rp) { + SetSockError("connect", -1, "failed"); + return false; + } + } + mode = CONNECT; + return true; +} + + +bool TcpSocket::Connect(IpAddrInfo& info) +{ + LLOG("TCP Connect addrinfo"); + Init(); + Reset(); + addrinfo *result = info.GetResult(); + return RawConnect(result); +} + +bool TcpSocket::Connect(const char *host, int port) +{ + LLOG("TCP Connect(" << host << ':' << port << ')'); + + Init(); + Reset(); + IpAddrInfo info; + if(!info.Execute(host, port)) { + SetSockError(Format("getaddrinfo(%s) failed", host)); + return false; + } + return Connect(info); +} + +void TcpSocket::RawClose() +{ + LLOG("TCP close " << (int)socket); + if(socket != INVALID_SOCKET) { + int res; +#if defined(PLATFORM_WIN32) + res = closesocket(socket); +#elif defined(PLATFORM_POSIX) + res = close(socket); +#else + #error Unsupported platform +#endif + if(res && !IsError()) + SetSockError("close"); + socket = INVALID_SOCKET; + } +} + +void TcpSocket::Close() +{ + if(ssl) + ssl->Close(); + else + RawClose(); + ssl.Clear(); +} + +bool TcpSocket::WouldBlock() +{ + int c = GetErrorCode(); +#ifdef PLATFORM_POSIX + return c == SOCKERR(EWOULDBLOCK) || c == SOCKERR(EAGAIN); +#endif +#ifdef PLATFORM_WIN32 + return c == SOCKERR(EWOULDBLOCK) || c == SOCKERR(ENOTCONN); +#endif +} + +int TcpSocket::RawSend(const void *buf, int amount) +{ + int res = send(socket, (const char *)buf, amount, 0); + if(res < 0 && WouldBlock()) + res = 0; + else + if(res == 0 || res < 0) + SetSockError("send"); + return res; +} + +int TcpSocket::Send(const void *buf, int amount) +{ + if(SSLHandshake()) + return 0; + return ssl ? ssl->Send(buf, amount) : RawSend(buf, amount); +} + +void TcpSocket::Shutdown() +{ + ASSERT(IsOpen()); + if(shutdown(socket, SD_SEND)) + SetSockError("shutdown(SD_SEND)"); +} + +String TcpSocket::GetHostName() +{ + Init(); + char buffer[256]; + gethostname(buffer, __countof(buffer)); + return buffer; +} + +bool TcpSocket::RawWait(dword flags) +{ + LLOG("Wait(" << timeout << ", " << flags << ")"); + if((flags & WAIT_READ) && ptr != end) + return true; + int end_time = msecs() + timeout; + if(socket == INVALID_SOCKET) + return false; + for(;;) { + if(IsError() || IsAbort()) + return false; + int to = end_time - msecs(); + if(WhenWait) + to = waitstep; + timeval *tvalp = NULL; + timeval tval; + if(!IsNull(timeout) || WhenWait) { + to = max(to, 0); + tval.tv_sec = to / 1000; + tval.tv_usec = 1000 * (to % 1000); + tvalp = &tval; + } + fd_set fdset[1]; + FD_ZERO(fdset); + FD_SET(socket, fdset); + int avail = select((int)socket + 1, + flags & WAIT_READ ? fdset : NULL, + flags & WAIT_WRITE ? fdset : NULL, + flags & WAIT_EXCEPTION ? fdset : NULL, tvalp); + LLOG("Wait select avail: " << avail); + if(avail < 0) { + SetSockError("wait"); + return false; + } + if(avail > 0) + return true; + if(to <= 0 && timeout) { + return false; + } + WhenWait(); + if(timeout == 0) + return false; + } +} + +bool TcpSocket::Wait(dword flags) +{ + return ssl ? ssl->Wait(flags) : RawWait(flags); +} + +int TcpSocket::Put(const char *s, int length) +{ + LLOG("Put " << socket << ": " << length); + ASSERT(IsOpen()); + if(length < 0 && s) + length = (int)strlen(s); + if(!s || length <= 0 || IsError() || IsAbort()) + return 0; + done = 0; + bool peek = false; + while(done < length) { + if(peek && !WaitWrite()) + return done; + peek = false; + int count = Send(s + done, length - done); + if(IsError() || timeout == 0 && count == 0 && peek) + return done; + if(count > 0) + done += count; + else + peek = true; + } + LLOG("//Put() -> " << done); + return done; +} + +int TcpSocket::RawRecv(void *buf, int amount) +{ + int res = recv(socket, (char *)buf, amount, 0); + if(res == 0) + is_eof = true; + else + if(res < 0 && WouldBlock()) + res = 0; + else + if(res < 0) + SetSockError("recv"); + LLOG("recv(" << socket << "): " << res << " bytes: " + << AsCString((char *)buf, (char *)buf + min(res, 16)) + << (res ? "" : IsEof() ? ", EOF" : ", WOULDBLOCK")); + return res; +} + +int TcpSocket::Recv(void *buffer, int maxlen) +{ + if(SSLHandshake()) + return 0; + return ssl ? ssl->Recv(buffer, maxlen) : RawRecv(buffer, maxlen); +} + +void TcpSocket::ReadBuffer() +{ + ptr = end = buffer; + if(WaitRead()) + end = buffer + Recv(buffer, BUFFERSIZE); +} + +int TcpSocket::Get_() +{ + if(!IsOpen() || IsError() || IsEof() || IsAbort()) + return -1; + ReadBuffer(); + return ptr < end ? *ptr++ : -1; +} + +int TcpSocket::Peek_() +{ + if(!IsOpen() || IsError() || IsEof() || IsAbort()) + return -1; + ReadBuffer(); + return ptr < end ? *ptr : -1; +} + +int TcpSocket::Get(void *buffer, int count) +{ + LLOG("Get " << count); + + if(!IsOpen() || IsError() || IsEof() || IsAbort()) + return 0; + + String out; + int l = end - ptr; + done = 0; + if(l > 0) + if(l < count) { + memcpy(buffer, ptr, l); + done += l; + ptr = end; + } + else { + memcpy(buffer, ptr, count); + ptr += count; + return count; + } + while(done < count && !IsError() && !IsEof()) { + if(!WaitRead()) + break; + int part = Recv((char *)buffer + done, count - done); + if(part > 0) + done += part; + if(timeout == 0) + break; + } + return done; +} + +String TcpSocket::Get(int count) +{ + if(count == 0) + return Null; + StringBuffer out(count); + int done = Get(out, count); + if(!done && IsEof()) + return String::GetVoid(); + out.SetLength(done); + return out; +} + +String TcpSocket::GetAll(int len) +{ + String s = Get(len); + return s.GetCount() == len ? s : String::GetVoid(); +} + +String TcpSocket::GetLine(int maxlen) +{ + String ln; + for(;;) { + int c = Peek(); + if(c < 0) + return String::GetVoid(); + Get(); + if(c == '\n') + return ln; + if(c != '\r') + ln.Cat(c); + } +} + +void TcpSocket::SetSockError(const char *context, int code, const char *errdesc) +{ + errorcode = code; + errordesc.Clear(); + if(socket != INVALID_SOCKET) + errordesc << "socket(" << (int)socket << ") / "; + errordesc << context << ": " << errdesc; + is_error = true; + LLOG("TCP ERROR " << errordesc); +} + +void TcpSocket::SetSockError(const char *context, const char *errdesc) +{ + SetSockError(context, GetErrorCode(), errdesc); +} + +void TcpSocket::SetSockError(const char *context) +{ + SetSockError(context, TcpSocketErrorDesc(GetErrorCode())); +} + +TcpSocket::SSL *(*TcpSocket::CreateSSL)(TcpSocket& socket); + +bool TcpSocket::StartSSL() +{ + ASSERT(IsOpen()); + if(!CreateSSL) { + SetSockError("StartSSL", -1, "Missing SSL support (Core/SSL)"); + return false; + } + if(!IsOpen()) { + SetSockError("StartSSL", -1, "Socket is not open"); + return false; + } + if(!IsOpen()) { + SetSockError("StartSSL", -1, "Socket is not connected"); + return false; + } + ssl = (*CreateSSL)(*this); + if(!ssl->Start()) { + ssl.Clear(); + return false; + } + SSLHandshake(); + return true; +} + +bool TcpSocket::SSLHandshake() +{ + if(ssl && (mode == CONNECT || mode == ACCEPT)) { + dword w = ssl->Handshake(); + if(w) { + Wait(w); + return ssl->Handshake(); + } + } + return false; +} + +int SocketWaitEvent::Wait(int timeout) +{ + FD_ZERO(read); + FD_ZERO(write); + FD_ZERO(exception); + int maxindex = -1; + for(int i = 0; i < socket.GetCount(); i++) { + const Tuple2& s = socket[i]; + if(s.a >= 0) { + const Tuple2& s = socket[i]; + if(s.b & WAIT_READ) + FD_SET(s.a, read); + if(s.b & WAIT_WRITE) + FD_SET(s.a, write); + if(s.b & WAIT_EXCEPTION) + FD_SET(s.a, exception); + maxindex = max(s.a, maxindex); + } + } + timeval *tvalp = NULL; + timeval tval; + if(!IsNull(timeout)) { + tval.tv_sec = timeout / 1000; + tval.tv_usec = 1000 * (timeout % 1000); + tvalp = &tval; + } + return select(maxindex + 1, read, write, exception, tvalp); +} + +dword SocketWaitEvent::Get(int i) const +{ + int s = socket[i].a; + if(s < 0) + return 0; + dword events = 0; + if(FD_ISSET(s, read)) + events |= WAIT_READ; + if(FD_ISSET(s, write)) + events |= WAIT_WRITE; + if(FD_ISSET(s, exception)) + events |= WAIT_EXCEPTION; + return events; +} + +SocketWaitEvent::SocketWaitEvent() +{ + FD_ZERO(read); + FD_ZERO(write); + FD_ZERO(exception); +} + +END_UPP_NAMESPACE diff --git a/uppsrc/Web/SSL/httpscli.cpp b/uppsrc/Web/SSL/httpscli.cpp index b4ea99028..8c7e9a804 100644 --- a/uppsrc/Web/SSL/httpscli.cpp +++ b/uppsrc/Web/SSL/httpscli.cpp @@ -1,128 +1,128 @@ -#ifndef flagNOSSL - -#include "WebSSL.h" - -NAMESPACE_UPP - -extern bool HttpClient_Trace__; - -#ifdef _DEBUG -#define LLOG(x) if(HttpClient_Trace__) RLOG(x); else; -#else -#define LLOG(x) -#endif - -HttpsClient::HttpsClient() -{ - secure = true; -} - -bool HttpsClient::ProxyConnect() -{ - if(use_proxy) { - int start_time = msecs(); - int end_time = msecs() + timeout_msecs; - while(!socket.PeekWrite(1000)) { - int time = msecs(); - if(time >= end_time) { - error = NFormat(t_("%s:%d: connecting to host timed out"), socket_host, socket_port); - Close(); - return false; - } - } - String host_port = host; - if(port) - host_port << ':' << port; - else - host_port << ":443"; - String request; - request << "CONNECT " << host_port << " HTTP/1.1\r\n" - << "Host: " << host_port << "\r\n"; - if(!IsNull(proxy_username)) - request << "Proxy-Authorization: Basic " - << Base64Encode(proxy_username + ':' + proxy_password) << "\r\n"; - request << "\r\n"; - LLOG(request); - int written = 0; - while(msecs() < end_time) { - int nwrite = socket.WriteWait(request.GetIter(written), min(request.GetLength() - written, 1000), 1000); - if(socket.IsError()) { - error = Socket::GetErrorText(); - Close(); - return false; - } - if((written += nwrite) >= request.GetLength()) - break; - } - if(written < request.GetLength()) { - error = NFormat(t_("%s:%d: timed out sending request to server"), host, port); - Close(); - return false; - } - String line = ReadUntilProgress('\n', start_time, end_time, false); - LLOG("P< " << line); - if(socket.IsError()) { - error = Socket::GetErrorText(); - Close(); - return false; - } - if(!line.StartsWith("HTTP") || line.Find(" 2") < 0) { - error = "Invalid proxy reply: " + line; - Close(); - return false; - } - while(line.GetCount()) { - line = ReadUntilProgress('\n', start_time, end_time, false); - if(*line.Last() == '\r') - line.Trim(line.GetCount() - 1); - LLOG("P< " << line << " len " << line.GetCount()); - if(socket.IsError()) { - error = Socket::GetErrorText(); - Close(); - return false; - } - } - use_proxy = false; - while(!socket.PeekWrite(1000)) { - int time = msecs(); - if(time >= end_time) { - error = NFormat(t_("%s:%d: connecting to host timed out"), socket_host, socket_port); - Close(); - return false; - } - } - } - return true; -} - -bool HttpsClient::IsSecure() -{ - return secure; -} - -bool HttpsClient::CreateClientSocket() -{ - if(!secure) - return HttpClient::CreateClientSocket(); - if(!ssl_context) { - ssl_context = new SSLContext; - if(!ssl_context->Create(const_cast(SSLv3_client_method()))) { - error = t_("Error creating SSL context."); - return false; - } - } - if(!SSLClientSocketUnsecured(socket, *ssl_context, socket_host, - socket_port ? socket_port : DEFAULT_HTTPS_PORT, true, NULL, 0, false)) { - error = Socket::GetErrorText(); - return false; - } - socket.Linger(0); - if(!ProxyConnect()) - return false; - SSLSecureSocket(socket); - return true; -} - -END_UPP_NAMESPACE - -#endif +#ifndef flagNOSSL + +#include "WebSSL.h" + +NAMESPACE_UPP + +extern bool HttpClient_Trace__; + +#ifdef _DEBUG +#define LLOG(x) if(HttpClient_Trace__) RLOG(x); else; +#else +#define LLOG(x) +#endif + +HttpsClient::HttpsClient() +{ + secure = true; +} + +bool HttpsClient::ProxyConnect() +{ + if(use_proxy) { + int start_time = msecs(); + int end_time = msecs() + timeout_msecs; + while(!socket.PeekWrite(1000)) { + int time = msecs(); + if(time >= end_time) { + error = NFormat(t_("%s:%d: connecting to host timed out"), socket_host, socket_port); + Close(); + return false; + } + } + String host_port = host; + if(port) + host_port << ':' << port; + else + host_port << ":443"; + String request; + request << "CONNECT " << host_port << " HTTP/1.1\r\n" + << "Host: " << host_port << "\r\n"; + if(!IsNull(proxy_username)) + request << "Proxy-Authorization: Basic " + << Base64Encode(proxy_username + ':' + proxy_password) << "\r\n"; + request << "\r\n"; + LLOG(request); + int written = 0; + while(msecs() < end_time) { + int nwrite = socket.WriteWait(request.GetIter(written), min(request.GetLength() - written, 1000), 1000); + if(socket.IsError()) { + error = Socket::GetErrorText(); + Close(); + return false; + } + if((written += nwrite) >= request.GetLength()) + break; + } + if(written < request.GetLength()) { + error = NFormat(t_("%s:%d: timed out sending request to server"), host, port); + Close(); + return false; + } + String line = ReadUntilProgress('\n', start_time, end_time, false); + LLOG("P< " << line); + if(socket.IsError()) { + error = Socket::GetErrorText(); + Close(); + return false; + } + if(!line.StartsWith("HTTP") || line.Find(" 2") < 0) { + error = "Invalid proxy reply: " + line; + Close(); + return false; + } + while(line.GetCount()) { + line = ReadUntilProgress('\n', start_time, end_time, false); + if(*line.Last() == '\r') + line.Trim(line.GetCount() - 1); + LLOG("P< " << line << " len " << line.GetCount()); + if(socket.IsError()) { + error = Socket::GetErrorText(); + Close(); + return false; + } + } + use_proxy = false; + while(!socket.PeekWrite(1000)) { + int time = msecs(); + if(time >= end_time) { + error = NFormat(t_("%s:%d: connecting to host timed out"), socket_host, socket_port); + Close(); + return false; + } + } + } + return true; +} + +bool HttpsClient::IsSecure() +{ + return secure; +} + +bool HttpsClient::CreateClientSocket() +{ + if(!secure) + return HttpClient::CreateClientSocket(); + if(!ssl_context) { + ssl_context = new SSLContext; + if(!ssl_context->Create(const_cast(SSLv3_client_method()))) { + error = t_("Error creating SSL context."); + return false; + } + } + if(!SSLClientSocketUnsecured(socket, *ssl_context, socket_host, + socket_port ? socket_port : DEFAULT_HTTPS_PORT, true, NULL, 0, false)) { + error = Socket::GetErrorText(); + return false; + } + socket.Linger(0); + if(!ProxyConnect()) + return false; + SSLSecureSocket(socket); + return true; +} + +END_UPP_NAMESPACE + +#endif