Netty Websocket SSL

This is a short guide to creating a Netty WebSocket client/server application that communicates over SSL/TLS (wss). It uses JKS keystores/truststores, as they are the most common way of storing private keys and certificates in the Java world.
This guide will show:
  •  How to create a private key along with a self signed certificate using Java keytool
  •  How to create a truststore containing the self signed certificate. This certificate will be used by the websocket client for ‘trusting’ the websocket server upon SSL connection
  •  A simple Netty websocket server example, exposing an SSL connection, using the private key generated in the above step
  •  A simple Netty websocket client example, establishing an SSL connection to the server, using the JKS truststore created in the above step

Creating Our Keystore/Truststore

Java keytool is an easy-to-use utility, shipped with the JDK, for performing various cryptographic tasks (e.g. generating keys, and generating and manipulating certificates). Its official documentation is easy to follow.
 
For our example, we need to generate a public/private key pair along with a self signed certificate. This can be done with the below command, the output of which is a JKS store, containing our private key and the self signed certificate.
 
keytool -genkeypair -alias TestKey -keyalg RSA -keysize 2048 -keystore TestKeystore.jks -storetype JKS
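The above JKS keystore will be used by our Netty WebSocket server to perform the SSL handshake. keytool will prompt for a keystore password and for the certificate's distinguished name.

Use changeit as both the keystore password and (if asked) the key password, because the commands and code below assume it. When asked for your first and last name, enter localhost: this becomes the certificate's CN, which should match the host name the client connects to.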
 
Once we have the keystore, we can extract the self-signed certificate and import it into a JKS truststore. Our WebSocket client uses this truststore to decide which certificates to trust. If the client does not trust the certificate presented by the server, the SSL handshake fails.
 
The command to extract the certificate into a .cert file is:
 
keytool -exportcert -rfc -alias TestKey -keystore TestKeystore.jks -storepass changeit -storetype JKS -file TestCert.cert
 
And the command to import that exported certificate into a JKS truststore is:
keytool -importcert -file TestCert.cert -keystore TestTruststore.jks -storepass changeit -storetype JKS
 

Example Netty Application

Now that we have both the keystore (to be used by the Server) and the truststore (to be used by the client) we can create our demo Netty client/server applications.
 
Effectively, all we need to do is add an SslHandler at the start of the ChannelPipeline. This SslHandler is created from an SslContext, which is built from the respective JKS keystore (server) or truststore (client).

Server

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import java.security.KeyStore;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;

/**
 * Minimal secure WebSocket (wss) echo server.
 *
 * <p>TLS is terminated by an {@code SslHandler} built from the {@code /TestKeystore.jks}
 * classpath resource. Every text frame received on the WebSocket path {@code "/"} is printed
 * and echoed back with {@code " back"} appended.
 */
public class NettyWSServer {

    /** Port the server listens on. */
    private static final int PORT = 10_000;

    /** Password of the JKS keystore and of the private key entry (see the keytool commands). */
    private static final String KEYSTORE_PASSWORD = "changeit";

    /**
     * Starts the wss server on port 10000 and blocks the calling thread until the server channel
     * is closed. The event loop groups are shut down when this method returns or throws.
     *
     * @throws IllegalStateException if the keystore cannot be loaded or the SSL context cannot be built
     */
    public void start() {
        // Build the SslContext once and share it across all accepted channels, instead of
        // reloading the keystore for every new connection.
        final SslContext sslContext;
        try {
            sslContext = createSSLContext();
        } catch (Exception e) {
            throw new IllegalStateException("Failed to create the server SSL context", e);
        }

        final NioEventLoopGroup bossGroup = new NioEventLoopGroup(1);
        final NioEventLoopGroup worker = new NioEventLoopGroup(1);
        try {
            final ServerBootstrap wsServer = new ServerBootstrap()
                .group(bossGroup, worker)
                .channel(NioServerSocketChannel.class)
                .handler(new LoggingHandler(LogLevel.INFO))
                .childHandler(new ChannelInitializer<Channel>() {
                    @Override
                    protected void initChannel(final Channel channel) throws Exception {
                        ChannelPipeline pipeline = channel.pipeline();

                        // TLS goes first so every handler after it works on decrypted bytes.
                        pipeline.addLast(sslContext.newHandler(channel.alloc()));

                        pipeline.addLast(new HttpServerCodec());
                        pipeline.addLast(new HttpObjectAggregator(64_000));
                        pipeline.addLast(new WebSocketServerProtocolHandler("/"));

                        pipeline.addLast(new SimpleChannelInboundHandler<TextWebSocketFrame>() {

                            @Override
                            protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame msg) throws Exception {
                                System.out.println("Message=" + msg.text());
                                ctx.writeAndFlush(new TextWebSocketFrame(msg.text() + " back"));
                            }

                            @Override
                            public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
                                // Report the error and drop only this connection; the server keeps running.
                                System.err.println("Closing connection: " + cause);
                                ctx.close();
                            }
                        });
                    }
                });

            // Wait for the bind to complete so that a failure (e.g. port already in use) is rethrown
            // here instead of silently completing the close future.
            final Channel serverChannel = wsServer.bind(PORT).syncUninterruptibly().channel();
            System.out.println("WS Server started");
            serverChannel.closeFuture().syncUninterruptibly();
        } finally {
            bossGroup.shutdownGracefully();
            worker.shutdownGracefully();
        }
    }

    /**
     * Loads {@code /TestKeystore.jks} from the classpath and builds a server-side Netty SslContext
     * from its key material.
     *
     * @return the server SslContext
     * @throws NullPointerException if the keystore resource is not on the classpath
     * @throws Exception if the keystore cannot be read or the SSL context cannot be built
     */
    private SslContext createSSLContext() throws Exception {
        KeyStore keystore = KeyStore.getInstance("JKS");
        try (java.io.InputStream keystoreStream = NettyWSServer.class.getResourceAsStream("/TestKeystore.jks")) {
            // KeyStore.load(null, ...) would silently create an EMPTY keystore; fail fast instead.
            keystore.load(
                java.util.Objects.requireNonNull(keystoreStream, "Classpath resource /TestKeystore.jks not found"),
                KEYSTORE_PASSWORD.toCharArray());
        }

        KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
        keyManagerFactory.init(keystore, KEYSTORE_PASSWORD.toCharArray());

        return SslContextBuilder.forServer(keyManagerFactory).build();
    }

    /** Entry point: starts the server and blocks until it shuts down. */
    public static void main(String[] args) {
        new NettyWSServer().start();
    }
}

Client

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker13;
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
import io.netty.handler.ssl.SslContextBuilder;
import java.net.URI;
import java.security.KeyStore;
import java.util.Objects;
import javax.net.ssl.TrustManagerFactory;

/**
 * Minimal secure WebSocket (wss) client.
 *
 * <p>It connects to {@code wss://localhost:10000} and trusts only the certificates in the
 * {@code /TestTruststore.jks} classpath resource. Once the WebSocket handshake completes it sends
 * {@code "Hello World"}, and it prints every text frame it receives.
 */
public class NettyWSClient {

    /** Host the client connects to; should match the CN/SAN of the server certificate. */
    private static final String HOST = "localhost";

    /** Port the server listens on. */
    private static final int PORT = 10_000;

    /** WebSocket URL used for the upgrade request ("wss://localhost:10000"). */
    private static final String URL = "wss://" + HOST + ":" + PORT;

    /** Password of the JKS truststore (see the keytool commands). */
    private static final String TRUSTSTORE_PASSWORD = "changeit";

    /**
     * Connects to the server and blocks until the connection is closed. The event loop is shut
     * down when this method returns or throws.
     *
     * @throws IllegalStateException if the truststore cannot be loaded
     */
    public void start() {
        // Load the truststore once, before connecting, so a missing or corrupt file is reported up front.
        final TrustManagerFactory trustManagerFactory;
        try {
            trustManagerFactory = createTrustManagerFactory();
        } catch (Exception e) {
            throw new IllegalStateException("Failed to load the client truststore", e);
        }

        final EventLoopGroup bossLoop = new NioEventLoopGroup(1);
        try {
            Bootstrap client = new Bootstrap()
                .group(bossLoop)
                .channel(NioSocketChannel.class)
                .handler(new ChannelInitializer<NioSocketChannel>() {
                    @Override
                    protected void initChannel(NioSocketChannel channel) throws Exception {
                        ChannelPipeline pipeline = channel.pipeline();

                        // Passing the peer host/port sends SNI and gives the SSLEngine the name it
                        // needs for hostname verification (enabled by default on some Netty versions).
                        pipeline.addLast(SslContextBuilder.forClient().trustManager(trustManagerFactory).build()
                            .newHandler(channel.alloc(), HOST, PORT));

                        pipeline.addLast(new HttpClientCodec(512, 512, 512));
                        pipeline.addLast(new HttpObjectAggregator(16_384));
                        final WebSocketClientHandshaker13 wsHandshaker = new WebSocketClientHandshaker13(new URI(URL),
                            WebSocketVersion.V13, "", false, new DefaultHttpHeaders(false), 64_000);
                        pipeline.addLast(new WebSocketClientProtocolHandler(wsHandshaker));

                        pipeline.addLast(new SimpleChannelInboundHandler<TextWebSocketFrame>() {

                            @Override
                            public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
                                if (evt == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
                                    System.out.println("Handshake completed. Sending Hello World");
                                    ctx.writeAndFlush(new TextWebSocketFrame("Hello World"));
                                }
                                // Propagate every event instead of swallowing it, so the rest of the
                                // pipeline still sees it.
                                super.userEventTriggered(ctx, evt);
                            }

                            @Override
                            protected void channelRead0(final ChannelHandlerContext ctx, TextWebSocketFrame msg) throws Exception {
                                System.out.println("Message=" + msg.text());
                            }

                            @Override
                            public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
                                // e.g. an SSL handshake failure when the server certificate is not trusted
                                System.err.println("Closing connection: " + cause);
                                ctx.close();
                            }
                        });
                    }
                });

            // Wait for the connect to complete so that a failure (e.g. connection refused) is
            // rethrown here instead of silently completing the close future.
            client.connect(HOST, PORT).syncUninterruptibly().channel().closeFuture().syncUninterruptibly();
        } finally {
            // Release the event loop's threads once the connection is gone.
            bossLoop.shutdownGracefully();
        }
    }

    /**
     * Loads {@code /TestTruststore.jks} from the classpath and wraps it in a TrustManagerFactory.
     *
     * @return a TrustManagerFactory initialised with the truststore
     * @throws NullPointerException if the truststore resource is not on the classpath
     * @throws Exception if the truststore cannot be read
     */
    private static TrustManagerFactory createTrustManagerFactory() throws Exception {
        KeyStore truststore = KeyStore.getInstance("JKS");
        try (java.io.InputStream truststoreStream = NettyWSClient.class.getResourceAsStream("/TestTruststore.jks")) {
            // KeyStore.load(null, ...) would silently create an EMPTY truststore; fail fast instead.
            truststore.load(
                Objects.requireNonNull(truststoreStream, "Classpath resource /TestTruststore.jks not found"),
                TRUSTSTORE_PASSWORD.toCharArray());
        }
        TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
        trustManagerFactory.init(truststore);
        return trustManagerFactory;
    }

    /** Entry point: connects to the server and blocks until the connection closes. */
    public static void main(String[] args) {
        new NettyWSClient().start();
    }
}