// stream_socket.cpp 3.1 KB
#include "socket-cpp/stream_socket.h"
#include "socket-cpp/exception.h"
#include <algorithm>
#include <climits>
#include <memory>

using namespace std::chrono;
using namespace osdev::components::socket-cpp;

// Builds a stream_socket around a handle freshly obtained from ::socket().
//
// domain   - forwarded unchanged as the first argument of ::socket().
// protocol - forwarded unchanged as the third argument of ::socket(); the
//            inline comment indicates the declaration defaults it to 0.
//
// Returns the new socket object. When the object tests false after
// construction, the value of get_last_error() is stored on it via clear()
// before it is returned.
stream_socket stream_socket::create(int domain, int protocol /*=0*/)
{
	stream_socket result(::socket(domain, COMM_TYPE, protocol));

	if (!result) {
		// Creation failed: remember why on the object handed to the caller.
		result.clear(get_last_error());
	}

	return result;
}

// Reads from the socket into buf using a single ::recv() call (flags = 0).
//
// buf - destination buffer.
// n   - maximum number of bytes to read.
//
// Returns whatever check_ret() yields for the ::recv() result.
ssize_t stream_socket::read(void *buf, size_t n)
{
	#if defined(_WIN32)
		// Winsock's recv() takes the length as an int. A bare int(n) would
		// wrap for n > INT_MAX, so clamp it; the caller then simply sees a
		// short read, which is already a legal outcome of read().
		const int len = (n > size_t(INT_MAX)) ? INT_MAX : int(n);
		return check_ret(::recv(handle(), reinterpret_cast<char*>(buf),
								len, 0));
	#else
		return check_ret(::recv(handle(), buf, n, 0));
	#endif
}

// Keeps calling read() until n bytes have been collected in buf, or until a
// read() returns zero or fails.
//
// buf - destination buffer; must have room for n bytes.
// n   - number of bytes wanted.
//
// A negative read() result is retried when last_error() reports EINTR.
//
// Returns the number of bytes gathered so far. Only when nothing at all was
// gathered and the last read() result was negative is that negative result
// returned instead.
ssize_t stream_socket::read_n(void *buf, size_t n)
{
	auto *dst = static_cast<uint8_t*>(buf);
	size_t total = 0;
	ssize_t rc = 0;

	while (total < n) {
		rc = read(dst + total, n - total);

		if (rc > 0) {
			total += size_t(rc);
			continue;
		}

		// rc <= 0 from here on: retry an interrupted call, otherwise stop.
		if (rc < 0 && last_error() == EINTR)
			continue;

		break;
	}

	if (total == 0 && rc < 0)
		return rc;

	return ssize_t(total);
}


// Scatter read: fills the buffers described by ranges, in order, with a
// single system call.
//
// ranges - list of (base pointer, length) buffer descriptors.
//
// Returns 0 immediately for an empty list. Otherwise:
//  - non-Windows: the check_ret() of the ::readv() result;
//  - Windows: SOCKET_ERROR when the checked ::WSARecv() result equals it,
//    else the byte count reported by ::WSARecv().
ssize_t stream_socket::read(const std::vector<iovec>& ranges)
{
	if (ranges.empty())
		return 0;

	#if !defined(_WIN32)
		return check_ret(::readv(handle(), ranges.data(), int(ranges.size())));
	#else
		// Translate each iovec into the WSABUF layout WSARecv() expects.
		// The final size is known, so allocate once up front.
		std::vector<WSABUF> bufs;
		bufs.reserve(ranges.size());
		for (const auto& range : ranges) {
			bufs.push_back({
				static_cast<ULONG>(range.iov_len),
				static_cast<CHAR*>(range.iov_base)
			});
		}

		DWORD flags = 0;
		DWORD nread = 0;
		DWORD nbuf = DWORD(bufs.size());

		auto ret = check_ret(::WSARecv(handle(), bufs.data(), nbuf, &nread, &flags, nullptr, nullptr));
		return ssize_t(ret == SOCKET_ERROR ? ret : nread);
	#endif
}

// Sets the socket's receive timeout (SOL_SOCKET / SO_RCVTIMEO).
//
// to - requested timeout, in microseconds.
//
// Returns the result of set_option().
bool stream_socket::read_timeout(const microseconds& to)
{
	#if defined(_WIN32)
		// Winsock takes the timeout as a DWORD count of milliseconds, and a
		// value of 0 means "never time out". duration_cast truncates, so a
		// positive sub-millisecond request would silently turn into 0; round
		// up instead so the timeout is never shorter than requested.
		auto ms = duration_cast<milliseconds>(to);
		if (ms < to)
			++ms;
		auto tv = DWORD(ms.count());
	#else
		auto tv = to_timeval(to);
	#endif
	return set_option(SOL_SOCKET, SO_RCVTIMEO, tv);
}

// Writes from buf to the socket using a single ::send() call (flags = 0).
//
// buf - source buffer.
// n   - maximum number of bytes to send.
//
// Returns whatever check_ret() yields for the ::send() result.
ssize_t stream_socket::write(const void *buf, size_t n)
{
	#if defined(_WIN32)
		// Winsock's send() takes the length as an int. A bare int(n) would
		// wrap for n > INT_MAX, so clamp it; the caller then sees a short
		// write, which write_n() and similar loops already handle.
		const int len = (n > size_t(INT_MAX)) ? INT_MAX : int(n);
		return check_ret(::send(handle(), reinterpret_cast<const char*>(buf),
								len, 0));
	#else
		return check_ret(::send(handle(), buf, n , 0));
	#endif
}

// Keeps calling write() until all n bytes of buf have been sent, or until a
// write() returns zero or fails.
//
// buf - source buffer holding n bytes.
// n   - number of bytes to send.
//
// A negative write() result is retried when last_error() reports EINTR.
//
// Returns the number of bytes sent so far. Only when nothing at all was sent
// and the last write() result was negative is that negative result returned
// instead.
ssize_t stream_socket::write_n(const void *buf, size_t n)
{
	const auto *src = static_cast<const uint8_t*>(buf);
	size_t total = 0;
	ssize_t rc = 0;

	while (total < n) {
		rc = write(src + total, n - total);

		if (rc > 0) {
			total += size_t(rc);
			continue;
		}

		// rc <= 0 from here on: retry an interrupted call, otherwise stop.
		if (rc < 0 && last_error() == EINTR)
			continue;

		break;
	}

	if (total == 0 && rc < 0)
		return rc;

	return ssize_t(total);
}

// Gather write: sends the buffers described by ranges, in order, with a
// single system call.
//
// ranges - list of (base pointer, length) buffer descriptors.
//
// Returns 0 immediately for an empty list. Otherwise:
//  - non-Windows: the check_ret() of the ::writev() result;
//  - Windows: SOCKET_ERROR when the checked ::WSASend() result equals it,
//    else the byte count reported by ::WSASend().
ssize_t stream_socket::write(const std::vector<iovec>& ranges)
{
	if (ranges.empty())
		return 0;

	#if !defined(_WIN32)
		return check_ret(::writev(handle(), ranges.data(), int(ranges.size())));
	#else
		// Translate each iovec into the WSABUF layout WSASend() expects.
		// The final size is known, so allocate once up front.
		std::vector<WSABUF> bufs;
		bufs.reserve(ranges.size());
		for (const auto& range : ranges) {
			bufs.push_back({
				static_cast<ULONG>(range.iov_len),
				static_cast<CHAR*>(range.iov_base)
			});
		}

		DWORD nwritten = 0;
		DWORD nmsg = DWORD(bufs.size());

		auto ret = check_ret(::WSASend(handle(), bufs.data(), nmsg, &nwritten, 0, nullptr, nullptr));
		return ssize_t(ret == SOCKET_ERROR ? ret : nwritten);
	#endif
}

// Sets the socket's send timeout (SOL_SOCKET / SO_SNDTIMEO).
//
// to - requested timeout, in microseconds.
//
// Returns the result of set_option().
bool stream_socket::write_timeout(const microseconds& to)
{
	#if defined(_WIN32)
		// Winsock takes the timeout as a DWORD count of milliseconds, and a
		// value of 0 means "never time out". duration_cast truncates, so a
		// positive sub-millisecond request would silently turn into 0; round
		// up instead so the timeout is never shorter than requested.
		auto ms = duration_cast<milliseconds>(to);
		if (ms < to)
			++ms;
		auto tv = DWORD(ms.count());
	#else
		auto tv = to_timeval(to);
	#endif

	return set_option(SOL_SOCKET, SO_SNDTIMEO, tv);
}