Skip to content

Commit 2d246ad

Browse files
authored
chore: better error reporting when connecting to tls with plain socket (#2740)
* chore: better error reporting when connecting to tls with plain socket --------- Signed-off-by: Roman Gershman <[email protected]>
1 parent 30c3f63 commit 2d246ad

File tree

4 files changed

+37
-31
lines changed

4 files changed

+37
-31
lines changed

src/facade/dragonfly_connection.cc

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66

77
#include <absl/container/flat_hash_map.h>
88
#include <absl/strings/match.h>
9+
#include <absl/strings/str_cat.h>
910
#include <mimalloc.h>
1011

1112
#include <numeric>
1213
#include <variant>
1314

14-
#include "absl/strings/str_cat.h"
1515
#include "base/flags.h"
1616
#include "base/io_buf.h"
1717
#include "base/logging.h"
@@ -613,10 +613,25 @@ void Connection::HandleRequests() {
613613
if (!(IsPrivileged() && no_tls_on_admin_port)) {
614614
// Must be done atomically before the premption point in Accept so that at any
615615
// point in time, the socket_ is defined.
616+
uint8_t buf[2];
617+
auto read_sz = socket_->Read(io::MutableBytes(buf));
618+
if (!read_sz || *read_sz < sizeof(buf)) {
619+
VLOG(1) << "Error reading from peer " << remote_ep << " " << read_sz.error().message();
620+
return;
621+
}
622+
if (buf[0] != 0x16 || buf[1] != 0x03) {
623+
VLOG(1) << "Bad TLS header "
624+
<< absl::StrCat(absl::Hex(buf[0], absl::kZeroPad2),
625+
absl::Hex(buf[1], absl::kZeroPad2));
626+
peer->Write(
627+
io::Buffer("-ERR Bad TLS header, double check "
628+
"if you enabled TLS for your client.\r\n"));
629+
}
630+
616631
{
617632
FiberAtomicGuard fg;
618633
unique_ptr<tls::TlsSocket> tls_sock = make_unique<tls::TlsSocket>(std::move(socket_));
619-
tls_sock->InitSSL(ssl_ctx_);
634+
tls_sock->InitSSL(ssl_ctx_, buf);
620635
SetSocket(tls_sock.release());
621636
}
622637
FiberSocketBase::AcceptResult aresult = socket_->Accept();

tests/dragonfly/connection_test.py

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
import asyncio
55
import time
66
from redis import asyncio as aioredis
7-
from redis.exceptions import ConnectionError as redis_conn_error
7+
from redis.exceptions import ConnectionError as redis_conn_error, ResponseError
88
import async_timeout
99
from dataclasses import dataclass
1010

1111
from . import dfly_args
12-
from .instance import DflyInstance
12+
from .instance import DflyInstance, DflyInstanceFactory
1313

1414
BASE_PORT = 1111
1515

@@ -564,7 +564,7 @@ async def test_large_cmd(async_client: aioredis.Redis):
564564

565565
@pytest.mark.asyncio
566566
async def test_reject_non_tls_connections_on_tls(with_tls_server_args, df_local_factory):
567-
server = df_local_factory.create(
567+
server: DflyInstance = df_local_factory.create(
568568
no_tls_on_admin_port="true",
569569
admin_port=1111,
570570
port=1211,
@@ -573,13 +573,12 @@ async def test_reject_non_tls_connections_on_tls(with_tls_server_args, df_local_
573573
)
574574
server.start()
575575

576-
client = aioredis.Redis(port=server.port, password="XXX")
577-
try:
578-
await client.execute_command("DBSIZE")
579-
except redis_conn_error:
580-
pass
576+
client = server.client(password="XXX")
577+
with pytest.raises((ResponseError)):
578+
await client.dbsize()
579+
await client.close()
581580

582-
client = aioredis.Redis(port=server.admin_port, password="XXX")
581+
client = server.admin_client(password="XXX")
583582
assert await client.dbsize() == 0
584583
await client.close()
585584

@@ -605,27 +604,19 @@ async def test_tls_full_auth(with_ca_tls_server_args, with_ca_tls_client_args, d
605604

606605

607606
@pytest.mark.asyncio
608-
async def test_tls_reject(with_ca_tls_server_args, with_tls_client_args, df_local_factory):
609-
server = df_local_factory.create(port=BASE_PORT, **with_ca_tls_server_args)
607+
async def test_tls_reject(
608+
with_ca_tls_server_args, with_tls_client_args, df_local_factory: DflyInstanceFactory
609+
):
610+
server: DflyInstance = df_local_factory.create(port=BASE_PORT, **with_ca_tls_server_args)
610611
server.start()
611612

612-
client = aioredis.Redis(port=server.port, **with_tls_client_args, ssl_cert_reqs=None)
613-
try:
613+
client = server.client(**with_tls_client_args, ssl_cert_reqs=None)
614+
await client.ping()
615+
await client.close()
616+
617+
client = server.client(**with_tls_client_args)
618+
with pytest.raises(redis_conn_error):
614619
await client.ping()
615-
except redis_conn_error:
616-
pass
617-
618-
client = aioredis.Redis(port=server.port, **with_tls_client_args)
619-
try:
620-
assert await client.dbsize() != 0
621-
except redis_conn_error:
622-
pass
623-
624-
client = aioredis.Redis(port=server.port, ssl_cert_reqs=None)
625-
try:
626-
assert await client.dbsize() != 0
627-
except redis_conn_error:
628-
pass
629620
await client.close()
630621

631622

tests/dragonfly/tls_conf_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ async def test_config_enable_tls(
141141
await client.ping()
142142

143143
# Connecting without TLS should fail.
144-
with pytest.raises(redis.exceptions.ConnectionError):
144+
with pytest.raises(redis.exceptions.ResponseError):
145145
async with server.client() as client_unauth:
146146
await client_unauth.ping()
147147

0 commit comments

Comments
 (0)