6096783d64171943c170af292bb45f7550c4d9ed
[android/platform/packages/providers/DownloadProvider.git] / tests / src / tests / http / MockWebServer.java
1 /*
2  * Copyright (C) 2010 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 package tests.http;
18
import android.text.TextUtils;
import android.util.Log;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.MalformedURLException;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.net.URL;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Queue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
45
46 /**
47  * A scriptable web server. Callers supply canned responses and the server
48  * replays them upon request in sequence.
49  *
50  * TODO: merge with the version from libcore/support/src/tests/java once it's in.
51  */
52 public final class MockWebServer {
53     static final String ASCII = "US-ASCII";
54
55     private final BlockingQueue<RecordedRequest> requestQueue
56             = new LinkedBlockingQueue<RecordedRequest>();
57     private final BlockingQueue<MockResponse> responseQueue
58             = new LinkedBlockingQueue<MockResponse>();
59     private int bodyLimit = Integer.MAX_VALUE;
60     private final ExecutorService executor = Executors.newCachedThreadPool();
61     // keep Futures around so we can rethrow any exceptions thrown by Callables
62     private final Queue<Future<?>> futures = new LinkedList<Future<?>>();
63
64     private int port = -1;
65     private ServerSocket serverSocket;
66
67     public int getPort() {
68         if (port == -1) {
69             throw new IllegalStateException("Cannot retrieve port before calling play()");
70         }
71         return port;
72     }
73
74     /**
75      * Returns a URL for connecting to this server.
76      *
77      * @param path the request path, such as "/".
78      */
79     public URL getUrl(String path) throws MalformedURLException {
80         return new URL("http://localhost:" + getPort() + path);
81     }
82
83     /**
84      * Sets the number of bytes of the POST body to keep in memory to the given
85      * limit.
86      */
87     public void setBodyLimit(int maxBodyLength) {
88         this.bodyLimit = maxBodyLength;
89     }
90
91     public void enqueue(MockResponse response) {
92         responseQueue.add(response);
93     }
94
95     /**
96      * Awaits the next HTTP request, removes it, and returns it. Callers should
97      * use this to verify the request sent was as intended.
98      */
99     public RecordedRequest takeRequest() throws InterruptedException {
100         return requestQueue.take();
101     }
102
103     public RecordedRequest takeRequestWithTimeout(long timeoutMillis) throws InterruptedException {
104         return requestQueue.poll(timeoutMillis, TimeUnit.MILLISECONDS);
105     }
106
107     public List<RecordedRequest> drainRequests() {
108         List<RecordedRequest> requests = new ArrayList<RecordedRequest>();
109         requestQueue.drainTo(requests);
110         return requests;
111     }
112
113     /**
114      * Starts the server, serves all enqueued requests, and shuts the server
115      * down.
116      */
117     public void play() throws IOException {
118         serverSocket = new ServerSocket(0);
119         serverSocket.setReuseAddress(true);
120         port = serverSocket.getLocalPort();
121         submitCallable(new Callable<Void>() {
122             public Void call() throws Exception {
123                 int count = 0;
124                 while (true) {
125                     if (count > 0 && responseQueue.isEmpty()) {
126                         serverSocket.close();
127                         executor.shutdown();
128                         return null;
129                     }
130
131                     serveConnection(serverSocket.accept());
132                     count++;
133                 }
134             }
135         });
136     }
137
138     /**
139      * shutdown the webserver
140      */
141     public void shutdown() throws IOException {
142         responseQueue.clear();
143         serverSocket.close();
144         executor.shutdown();
145     }
146
147     private void serveConnection(final Socket s) {
148         submitCallable(new Callable<Void>() {
149             public Void call() throws Exception {
150                 InputStream in = new BufferedInputStream(s.getInputStream());
151                 OutputStream out = new BufferedOutputStream(s.getOutputStream());
152
153                 int sequenceNumber = 0;
154                 while (true) {
155                     RecordedRequest request = readRequest(in, sequenceNumber);
156                     if (request == null) {
157                         if (sequenceNumber == 0) {
158                             throw new IllegalStateException("Connection without any request!");
159                         } else {
160                             break;
161                         }
162                     }
163                     requestQueue.add(request);
164                     MockResponse response = sendResponse(out, request);
165                     if (response.shouldCloseConnectionAfter()) {
166                         break;
167                     }
168                     sequenceNumber++;
169                 }
170
171                 in.close();
172                 out.close();
173                 return null;
174             }
175         });
176     }
177
178     private void submitCallable(Callable<?> callable) {
179         Future<?> future = executor.submit(callable);
180         futures.add(future);
181     }
182
183     /**
184      * Check for and raise any exceptions that have been thrown by child threads.  Will not block on
185      * children still running.
186      * @throws ExecutionException for the first child thread that threw an exception
187      */
188     public void checkForExceptions() throws ExecutionException, InterruptedException {
189         final int originalSize = futures.size();
190         for (int i = 0; i < originalSize; i++) {
191             Future<?> future = futures.remove();
192             try {
193                 future.get(0, TimeUnit.SECONDS);
194             } catch (TimeoutException e) {
195                 futures.add(future); // still running
196             }
197         }
198     }
199
200     /**
201      * @param sequenceNumber the index of this request on this connection.
202      */
203     private RecordedRequest readRequest(InputStream in, int sequenceNumber) throws IOException {
204         String request = readAsciiUntilCrlf(in);
205         if (request.equals("")) {
206             return null; // end of data; no more requests
207         }
208
209         List<String> headers = new ArrayList<String>();
210         int contentLength = -1;
211         boolean chunked = false;
212         String header;
213         while (!(header = readAsciiUntilCrlf(in)).equals("")) {
214             headers.add(header);
215             String lowercaseHeader = header.toLowerCase();
216             if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) {
217                 contentLength = Integer.parseInt(header.substring(15).trim());
218             }
219             if (lowercaseHeader.startsWith("transfer-encoding:") &&
220                     lowercaseHeader.substring(18).trim().equals("chunked")) {
221                 chunked = true;
222             }
223         }
224
225         boolean hasBody = false;
226         TruncatingOutputStream requestBody = new TruncatingOutputStream();
227         List<Integer> chunkSizes = new ArrayList<Integer>();
228         if (contentLength != -1) {
229             hasBody = true;
230             transfer(contentLength, in, requestBody);
231         } else if (chunked) {
232             hasBody = true;
233             while (true) {
234                 int chunkSize = Integer.parseInt(readAsciiUntilCrlf(in).trim(), 16);
235                 if (chunkSize == 0) {
236                     readEmptyLine(in);
237                     break;
238                 }
239                 chunkSizes.add(chunkSize);
240                 transfer(chunkSize, in, requestBody);
241                 readEmptyLine(in);
242             }
243         }
244
245         if (request.startsWith("GET ")) {
246             if (hasBody) {
247                 throw new IllegalArgumentException("GET requests should not have a body!");
248             }
249         } else if (request.startsWith("POST ")) {
250             if (!hasBody) {
251                 throw new IllegalArgumentException("POST requests must have a body!");
252             }
253         } else {
254             throw new UnsupportedOperationException("Unexpected method: " + request);
255         }
256         return new RecordedRequest(request, headers, chunkSizes,
257                 requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber);
258     }
259
260     /**
261      * Returns a response to satisfy {@code request}.
262      */
263     private MockResponse sendResponse(OutputStream out, RecordedRequest request)
264             throws InterruptedException, IOException {
265         if (responseQueue.isEmpty()) {
266             throw new IllegalStateException("Unexpected request: " + request);
267         }
268         MockResponse response = responseQueue.take();
269         writeResponse(out, response, false);
270         if (response.getNumPackets() > 0) {
271             // there are continuing packets to send as part of this response.
272             for (int i = 0; i < response.getNumPackets(); i++) {
273                 writeResponse(out, response, true);
274                 // delay sending next continuing response just a little bit
275                 Thread.sleep(100);
276             }
277         }
278         return response;
279      }
280
281     private void writeResponse(OutputStream out, MockResponse response,
282             boolean continuingPacket) throws IOException {
283         if (continuingPacket) {
284             // this is a continuing response - just send the body - no headers, status
285             out.write(response.getBody());
286             out.flush();
287             return;
288         }
289         out.write((response.getStatus() + "\r\n").getBytes(ASCII));
290         for (String header : response.getHeaders()) {
291             out.write((header + "\r\n").getBytes(ASCII));
292         }
293         out.write(("\r\n").getBytes(ASCII));
294         out.write(response.getBody());
295         out.flush();
296     }
297
298     /**
299      * Transfer bytes from {@code in} to {@code out} until either {@code length}
300      * bytes have been transferred or {@code in} is exhausted.
301      */
302     private void transfer(int length, InputStream in, OutputStream out) throws IOException {
303         byte[] buffer = new byte[1024];
304         while (length > 0) {
305             int count = in.read(buffer, 0, Math.min(buffer.length, length));
306             if (count == -1) {
307                 return;
308             }
309             out.write(buffer, 0, count);
310             length -= count;
311         }
312     }
313
314     /**
315      * Returns the text from {@code in} until the next "\r\n", or null if
316      * {@code in} is exhausted.
317      */
318     private String readAsciiUntilCrlf(InputStream in) throws IOException {
319         StringBuilder builder = new StringBuilder();
320         while (true) {
321             int c = in.read();
322             if (c == '\n' && builder.length() > 0 && builder.charAt(builder.length() - 1) == '\r') {
323                 builder.deleteCharAt(builder.length() - 1);
324                 return builder.toString();
325             } else if (c == -1) {
326                 return builder.toString();
327             } else {
328                 builder.append((char) c);
329             }
330         }
331     }
332
333     private void readEmptyLine(InputStream in) throws IOException {
334         String line = readAsciiUntilCrlf(in);
335         if (!line.equals("")) {
336             throw new IllegalStateException("Expected empty but was: " + line);
337         }
338     }
339
340     /**
341      * An output stream that drops data after bodyLimit bytes.
342      */
343     private class TruncatingOutputStream extends ByteArrayOutputStream {
344         private int numBytesReceived = 0;
345         @Override public void write(byte[] buffer, int offset, int len) {
346             numBytesReceived += len;
347             super.write(buffer, offset, Math.min(len, bodyLimit - count));
348         }
349         @Override public void write(int oneByte) {
350             numBytesReceived++;
351             if (count < bodyLimit) {
352                 super.write(oneByte);
353             }
354         }
355     }
356 }