Skip to content

Commit b02695e

Browse files
committed
test: add client test
1 parent b0d15e3 commit b02695e

File tree

1 file changed

+181
-0
lines changed

1 file changed

+181
-0
lines changed

ktls/tests/client.rs

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
//! Test: client connect to real world websites.
2+
3+
use core::num::NonZeroUsize;
4+
use core::time::Duration;
5+
use std::io;
6+
7+
use ktls::KtlsStream;
8+
use rustls::pki_types::ServerName;
9+
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};
10+
use tokio::net::TcpStream;
11+
use tokio::time::timeout;
12+
13+
mod common {
14+
include!("../examples/common/mod.rs");
15+
}
16+
17+
#[test_case::test_matrix(
18+
[
19+
"www.google.com", // Google CDN
20+
"www.bing.com", // Azure CDN
21+
"github.com", // Azure CDN
22+
"www.baidu.com", // Baidu CDN
23+
"stackoverflow.com", // Cloudflare CDN
24+
"fastly.com", // Fastly CDN
25+
]
26+
)]
27+
#[tokio::test]
28+
async fn test_connecct_sites(server_name: &'static str) -> io::Result<()> {
29+
timeout(
30+
Duration::from_secs(10),
31+
test_connecct_sites_impl(server_name),
32+
)
33+
.await
34+
.unwrap_or_else(|e| {
35+
tracing::warn!("Test to {server_name} timed out?");
36+
37+
Err(io::Error::new(io::ErrorKind::TimedOut, e))
38+
})
39+
}
40+
41+
async fn test_connecct_sites_impl(server_name: &'static str) -> io::Result<()> {
42+
let _ = tracing_subscriber::fmt()
43+
.with_env_filter(tracing_subscriber::EnvFilter::new("TRACE"))
44+
.pretty()
45+
.try_init();
46+
47+
let Some(compatible_cipher_suites) = common::compatible_cipher_suites() else {
48+
return Ok(());
49+
};
50+
51+
let Ok(Ok(socket)) = timeout(
52+
Duration::from_secs(1),
53+
TcpStream::connect(format!("{server_name}:443")),
54+
)
55+
.await
56+
else {
57+
tracing::warn!("Failed to connect to {server_name}, skipped.");
58+
59+
return Ok(());
60+
};
61+
62+
let connector = common::client::get_ktls_connector(compatible_cipher_suites);
63+
64+
let mut ktls_stream = connector
65+
.try_connect(socket, ServerName::try_from(server_name).unwrap())
66+
.await
67+
.map_err(io::Error::other)?;
68+
69+
// Test 1
70+
tracing::info!("First request to {server_name}");
71+
http_request(&mut ktls_stream, server_name).await?;
72+
73+
// Test 2
74+
tracing::info!("Second request to {server_name}");
75+
http_request(&mut ktls_stream, server_name).await?;
76+
77+
Ok(())
78+
}
79+
80+
async fn http_request(
81+
ktls_stream: &mut KtlsStream<TcpStream>,
82+
server_name: &str,
83+
) -> io::Result<()> {
84+
// Write HTTP/1.1 request
85+
{
86+
ktls_stream
87+
.write_all(
88+
format!(
89+
"GET / HTTP/1.1\r\nHost: {}\r\nconnection: keep-alive\r\naccept-encoding: \
90+
identity\r\ntransfer-encoding: identity\r\n\r\n",
91+
server_name
92+
)
93+
.as_bytes(),
94+
)
95+
.await?;
96+
97+
tracing::debug!("Request sent to {server_name}");
98+
99+
// Read response
100+
let mut response = Vec::new();
101+
102+
let mut buf_stream = tokio::io::BufStream::new(ktls_stream);
103+
104+
let mut content_length = None;
105+
106+
loop {
107+
let total_has_read = response.len();
108+
109+
let has_read = buf_stream.read_until(b'\n', &mut response).await?;
110+
111+
if has_read == 0 || response.ends_with(b"\r\n\r\n") {
112+
break;
113+
}
114+
115+
let has_read_bytes = &response[total_has_read..];
116+
tracing::trace!(
117+
"Received from {server_name}: {}",
118+
String::from_utf8_lossy(has_read_bytes)
119+
);
120+
121+
const PREFIX: &[u8; 16] = b"content-length: ";
122+
123+
if has_read_bytes
124+
.get(..PREFIX.len())
125+
.map(|v| v.eq_ignore_ascii_case(PREFIX))
126+
== Some(true)
127+
{
128+
let v = std::str::from_utf8(&has_read_bytes[PREFIX.len()..])
129+
.expect("content length should be a number string")
130+
.trim()
131+
.parse::<usize>()
132+
.expect("content length should be a number");
133+
134+
content_length = Some(v);
135+
}
136+
}
137+
138+
// Read body
139+
{
140+
let Some(Some(content_length)) = content_length.map(NonZeroUsize::new) else {
141+
tracing::warn!("No body found in response from {server_name}, skipped.");
142+
143+
return Ok(());
144+
};
145+
146+
tracing::debug!(
147+
"Headers received from {server_name}, reading body ({content_length} bytes)..."
148+
);
149+
150+
response.reserve(content_length.get());
151+
152+
#[allow(unsafe_code)]
153+
// Safety: we have reserved enough space above.
154+
buf_stream
155+
.read_exact(unsafe {
156+
std::slice::from_raw_parts_mut(
157+
response.as_mut_ptr().add(response.len()),
158+
content_length.get(),
159+
)
160+
})
161+
.await?;
162+
163+
#[allow(unsafe_code)]
164+
// Safety: we just initialized the buffer above.
165+
unsafe {
166+
response.set_len(response.len() + content_length.get());
167+
}
168+
}
169+
170+
let response = String::from_utf8_lossy(&response);
171+
172+
tracing::info!("Got response from {server_name}");
173+
174+
tracing::trace!(
175+
"Response from {server_name}: {:#?} (...)",
176+
&response[..64.min(response.len())]
177+
);
178+
}
179+
180+
Ok(())
181+
}

0 commit comments

Comments
 (0)