1 package ca.uhn.hl7v2.hoh.sockets;
2
import java.io.IOException;
import java.io.InputStream;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLServerSocket;

import ca.uhn.hl7v2.hoh.util.RandomServerPortProvider;
import org.junit.Before;
import org.junit.Test;
import org.mortbay.jetty.Server;
import org.mortbay.jetty.security.SslSelectChannelConnector;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
20
21 public class CustomCertificateTlsSocketFactoryTest {
22
23 private static final org.slf4j.Logger ourLog = org.slf4j.LoggerFactory.getLogger(CustomCertificateTlsSocketFactoryTest.class);
24
25 private int myPort;
26
27 @Before
28 public void before() {
29 myPort = RandomServerPortProvider.findFreePort();
30 }
31
32 @Test
33 public void testConnectToNonTrustedSocket() throws IOException, InterruptedException {
34
35 CustomCertificateTlsSocketFactory badServer = createTrustedServerSocketFactory();
36 Receiver receiver = new Receiver(badServer);
37 receiver.start();
38 Thread.sleep(500);
39
40 try {
41
42 CustomCertificateTlsSocketFactory goodClient = createNonTrustedClientSocketFactory();
43 Socket client = goodClient.createClientSocket();
44 client.connect(new InetSocketAddress("localhost", myPort));
45
46 client.getOutputStream().write("HELLO WORLD".getBytes());
47 fail();
48
49 } catch (SSLHandshakeException e) {
50
51 }
52 }
53
54 @Test
55 public void testConnectToTrustedSocket() throws IOException, InterruptedException {
56
57 CustomCertificateTlsSocketFactory goodServer = createTrustedServerSocketFactory();
58 Receiver receiver = new Receiver(goodServer);
59 receiver.start();
60 Thread.sleep(500);
61
62 CustomCertificateTlsSocketFactory goodClient = new CustomCertificateTlsSocketFactory();
63 goodClient.setKeystoreFilename("src/test/resources/truststore.jks");
64
65 Socket client = goodClient.createClientSocket();
66 client.connect(new InetSocketAddress("localhost", myPort));
67
68 client.getOutputStream().write("HELLO WORLD".getBytes());
69 client.close();
70
71 Thread.sleep(500);
72 String expected = "HELLO WORLD";
73 String actual = receiver.myString;
74 assertEquals(expected, actual);
75
76 }
77
78 public static CustomCertificateTlsSocketFactory createNonTrustedClientSocketFactory() {
79 CustomCertificateTlsSocketFactory goodClient = new CustomCertificateTlsSocketFactory();
80 goodClient.setKeystoreFilename("src/test/resources/truststore2.jks");
81 goodClient.setKeystorePassphrase("trustpassword");
82 return goodClient;
83 }
84
85 public static StandardSocketFactory createNonSslServerSocketFactory() {
86 StandardSocketFactory goodClient = new StandardSocketFactory();
87 return goodClient;
88 }
89
90 public static CustomCertificateTlsSocketFactory createTrustedClientSocketFactory() {
91 CustomCertificateTlsSocketFactory goodClient = new CustomCertificateTlsSocketFactory();
92 goodClient.setKeystoreFilename("src/test/resources/truststore.jks");
93
94 return goodClient;
95 }
96
97 public static CustomCertificateTlsSocketFactory createTrustedServerSocketFactory() {
98 CustomCertificateTlsSocketFactory goodServer = new CustomCertificateTlsSocketFactory();
99 goodServer.setKeystoreFilename("src/test/resources/keystore.jks");
100 goodServer.setKeystorePassphrase("changeit");
101 return goodServer;
102 }
103
104 public static void main(String[] args) throws Exception {
105
106 Server s = new Server();
107
108 SslSelectChannelConnector ssl = new SslSelectChannelConnector();
109 ssl.setKeystore("src/test/resources/keystore.jks");
110 ssl.setPassword("changeit");
111 ssl.setKeyPassword("changeit");
112 ssl.setPort(60647);
113
114 s.addConnector(ssl);
115 s.start();
116 }
117
118 private class Receiver extends Thread {
119
120 private ISocketFactory myFactory;
121 private ServerSocket myServer;
122 private String myString;
123
124 public Receiver(ISocketFactory theFactory) {
125 myFactory = theFactory;
126 }
127
128 @Override
129 public void run() {
130 try {
131
132 ourLog.info("Listening on port {}", myPort);
133
134 myServer = myFactory.createServerSocket();
135 myServer.bind(new InetSocketAddress(myPort));
136 myServer.setSoTimeout(3000);
137
138 if (myServer instanceof SSLServerSocket) {
139 SSLServerSocket ss = (SSLServerSocket) myServer;
140 ourLog.info(Arrays.asList(ss.getEnabledCipherSuites()).toString());
141 }
142
143 Socket socket = myServer.accept();
144 socket.setSoTimeout(2000);
145
146 InputStream is = socket.getInputStream();
147 StringBuilder b = new StringBuilder();
148 for (;;) {
149 int next = is.read();
150 if (next == -1) {
151 break;
152 } else {
153 b.append((char) next);
154 ourLog.info("Received: " + b);
155 }
156 }
157
158 myString = b.toString();
159 } catch (Throwable e) {
160 ourLog.error("Failed", e);
161 fail(e.getMessage());
162 } finally {
163 if (myServer != null) {
164 try {
165 myServer.close();
166 } catch (Exception e) {
167 e.printStackTrace();
168 }
169 }
170 }
171 }
172
173 }
174
175 }