import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Level;
import java.util.logging.Logger;
public class TeapotHttpServer {
private static final String TEAPOT_HTTP_RESP_STR =
"HTTP/1.1 418 I'm a teapot\r\n" +
"Content-Length: 130\r\n" +
"Content-Type: text/html\r\n" +
"Server: java-nio\r\n\r\n" +
"
418 I'm teapot418 I'm a teapot
";
private static final byte[] TEAPOT_HTTP_RESP_BYTES = TEAPOT_HTTP_RESP_STR.getBytes(StandardCharsets.UTF_8);
private static final String HOSTNAME = "localhost";
private static final int HTTP_PORT = 8080;
private static final Logger LOGGER = Logger.getLogger("TeapotHttpServer");
public static void main(String[] args) throws IOException {
var isShuttingDown = new AtomicBoolean(false);
addShutdownHook(isShuttingDown);
try (var serverSocketChannel = ServerSocketChannel.open();
var selector = Selector.open())
{
serverSocketChannel.configureBlocking(false); // use non-blocking io
// listen for ACCEPT events on server's socket channel
serverSocketChannel.register(selector, SelectionKey.OP_ACCEPT);
serverSocketChannel.bind(new InetSocketAddress(HOSTNAME, HTTP_PORT));
LOGGER.info("Listen on http://" + HOSTNAME + ":" + HTTP_PORT);
while (!isShuttingDown.get()) {
try {
if (selector.select() < 1) { // selector.select() method is blocking method (waits for new events)
// no new events
// for examples in case of selector.wakeup() invocation
continue;
}
var selectedKeysIterator = selector.selectedKeys().iterator();
while (selectedKeysIterator.hasNext()) {
var selectionKey = selectedKeysIterator.next();
selectedKeysIterator.remove(); // do not forget to remove
if (selectionKey.isAcceptable()) {
// handle ACCEPT event on server's socket channel
handleAccept(selectionKey, selector);
} else if (selectionKey.isReadable()) {
// handle READ event on client's socket channel
handleRead(selectionKey);
} else if (selectionKey.isWritable()) {
// handle WRITE event on client's socket channel
handleWrite(selectionKey);
}
}
} catch (Exception e) {
LOGGER.log(Level.SEVERE, "Error in server loop", e);
}
}
}
}
private static void addShutdownHook(AtomicBoolean isShuttingDown) {
Runtime.getRuntime().addShutdownHook(new Thread(() -> isShuttingDown.set(true)));
}
private static void handleAccept(SelectionKey selectionKey, Selector selector) {
try {
SocketChannel socketChannel = ((ServerSocketChannel) selectionKey.channel()).accept();
if (socketChannel == null) {
return;
}
socketChannel.configureBlocking(false); // use non-blocking io for client's channel
// listen for READ events
socketChannel.register(selector, SelectionKey.OP_READ, new ChannelAttachment());
LOGGER.info("Connection from " + socketChannel.getRemoteAddress() + " accepted");
} catch (Exception e) {
LOGGER.log(Level.SEVERE, "Error in `handleAccept`", e);
}
}
private static void handleRead(SelectionKey selectionKey) {
try {
var socketChannel = (SocketChannel) selectionKey.channel();
var remoteAddress = socketChannel.getRemoteAddress();
var channelAttachment = (ChannelAttachment) selectionKey.attachment();
var httpMessageParser = (HttpMessageParser) channelAttachment.getMessageParser();
var reqByteBuf = ByteBuffer.allocate(256);
// Read request body from client's channel
int read;
while ((read = socketChannel.read(reqByteBuf)) > 0) {
httpMessageParser.consume(reqByteBuf);
}
if (read == -1) {
// seems we need to close channel
LOGGER.info("Got CLOSE from " + socketChannel.getRemoteAddress());
httpMessageParser.clearState();
selectionKey.attach(null);
selectionKey.cancel();
selectionKey.channel().close();
return;
}
if (!httpMessageParser.isFullyRead()) {
// http message was not fully read
// wait for next read cycle
return;
}
var method = httpMessageParser.getMethod();
var url = httpMessageParser.getUrl();
LOGGER.info("Got request from " + remoteAddress + ". Method: " + method + ", url: " + url);
httpMessageParser.clearState(); // do not forget to clear state if message was fully read
var respByteBuf = ByteBuffer.wrap(TEAPOT_HTTP_RESP_BYTES);
int remaining = respByteBuf.remaining();
int write = socketChannel.write(respByteBuf);
if (write < remaining) {
channelAttachment.setRespByteBuffer(respByteBuf);
// register also for WRITE events
// we will write remaining bytes when socket will be ready for writing
selectionKey.interestOpsOr(SelectionKey.OP_WRITE);
} else {
LOGGER.info("Teapot http response has been sent to " + remoteAddress);
}
} catch (Exception e) {
LOGGER.log(Level.SEVERE, "Error in `handleRead`", e);
}
}
private static void handleWrite(SelectionKey selectionKey) {
try {
SocketChannel socketChannel = (SocketChannel) selectionKey.channel();
var remoteAddress = socketChannel.getRemoteAddress();
var channelAttachment = (ChannelAttachment) selectionKey.attachment();
var respByteBuf = (ByteBuffer) channelAttachment.getRespByteBuffer();
int remaining = respByteBuf.remaining();
int write = socketChannel.write(respByteBuf);
if (write == remaining) {
// message has been fully sent to the client
LOGGER.info("Teapot http response has been sent to " + remoteAddress);
// we don't interested on WRITE events now
selectionKey.interestOps(SelectionKey.OP_READ);
channelAttachment.setRespByteBuffer(null);
}
} catch (Exception e) {
LOGGER.log(Level.SEVERE, "Error in `handleWrite`");
}
}
}
class ChannelAttachment {
// one instance per channel is enough
private final HttpMessageParser messageParser = new HttpMessageParser();
private ByteBuffer respByteBuffer;
public HttpMessageParser getMessageParser() {
return messageParser;
}
public ByteBuffer getRespByteBuffer() {
return respByteBuffer;
}
public void setRespByteBuffer(ByteBuffer respByteBuffer) {
this.respByteBuffer = null;
this.respByteBuffer = respByteBuffer;
}
}
// Only parses http messages without body
class HttpMessageParser {
private static final char[] END_MARK = new char[]{'\r', '\n', '\r', '\n'};
private final StringBuilder messageCollector = new StringBuilder();
private String url;
private String method;
private boolean fullyRead;
public void consume(ByteBuffer byteBuffer) {
byteBuffer.flip(); // switch to read-mode
String messagePart = StandardCharsets.UTF_8.decode(byteBuffer).toString();
this.messageCollector.append(messagePart);
if (checkIsFullyRead()) {
String fullHttpMessage = messageCollector.toString();
String[] lines = fullHttpMessage.split("\r\n");
String[] firstLineParts = lines[0].split("\\s");
this.method = firstLineParts[0];
this.url = firstLineParts[1];
this.fullyRead = true;
}
}
// check if reached \r\n\r\n char sequence
private boolean checkIsFullyRead() {
int len = messageCollector.length();
boolean matched = false;
for (int i = 0; i < END_MARK.length; i++) {
matched = messageCollector.charAt((len - 4) + i) == END_MARK[i];
if (!matched) {
break;
}
}
return matched;
}
public boolean isFullyRead() {
return this.fullyRead;
}
public String getUrl() {
return this.url;
}
public String getMethod() {
return this.method;
}
public void clearState() {
this.url = null;
this.method = null;
this.fullyRead = false;
this.messageCollector.delete(0, messageCollector.length());
}
}