Skip to content

Commit 9b6ee18

Browse files
authored
chore(string_family): Refactor SetCmd (#2919)
* chore(string_family): Refactor SetCmd --------- Signed-off-by: Vladislav Oleshko <[email protected]>
1 parent e352edd commit 9b6ee18

File tree

3 files changed

+107
-158
lines changed

3 files changed

+107
-158
lines changed

src/server/string_family.cc

Lines changed: 92 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@
1717
#include "base/logging.h"
1818
#include "base/stl_util.h"
1919
#include "facade/cmd_arg_parser.h"
20+
#include "facade/op_status.h"
21+
#include "redis/redis_aux.h"
2022
#include "server/acl/acl_commands_def.h"
2123
#include "server/command_registry.h"
2224
#include "server/conn_context.h"
2325
#include "server/engine_shard_set.h"
2426
#include "server/error.h"
2527
#include "server/journal/journal.h"
28+
#include "server/table.h"
2629
#include "server/tiered_storage.h"
2730
#include "server/transaction.h"
2831

@@ -321,8 +324,7 @@ void OpMSet(const OpArgs& op_args, ArgSlice args, atomic_bool* success) {
321324
size_t i = 0;
322325
for (; i < args.size(); i += 2) {
323326
DVLOG(1) << "MSet " << args[i] << ":" << args[i + 1];
324-
OpResult<optional<string>> res = sg.Set(params, args[i], args[i + 1]);
325-
if (res.status() != OpStatus::OK) { // OOM for example.
327+
if (sg.Set(params, args[i], args[i + 1]) != OpStatus::OK) { // OOM for example.
326328
success->store(false);
327329
break;
328330
}
@@ -346,18 +348,6 @@ void OpMSet(const OpArgs& op_args, ArgSlice args, atomic_bool* success) {
346348
}
347349
}
348350

349-
// See comment for SetCmd::Set() for when and how OpResult's value (i.e. optional<string>) is set.
350-
OpResult<optional<string>> SetGeneric(ConnectionContext* cntx, const SetCmd::SetParams& sparams,
351-
string_view key, string_view value, bool manual_journal) {
352-
DCHECK(cntx->transaction);
353-
354-
auto cb = [&](Transaction* t, EngineShard* shard) {
355-
SetCmd sg(t->GetOpArgs(shard), manual_journal);
356-
return sg.Set(sparams, key, value);
357-
};
358-
return cntx->transaction->ScheduleSingleHopT(std::move(cb));
359-
}
360-
361351
// emission_interval_ms assumed to be positive
362352
// limit is assumed to be positive
363353
OpResult<array<int64_t, 5>> OpThrottle(const OpArgs& op_args, const string_view key,
@@ -464,32 +454,6 @@ OpResult<array<int64_t, 5>> OpThrottle(const OpArgs& op_args, const string_view
464454
return array<int64_t, 5>{limited ? 1 : 0, limit, remaining, retry_after_ms, reset_after_ms};
465455
}
466456

467-
class SetResultBuilder {
468-
public:
469-
explicit SetResultBuilder(bool return_prev_value) : return_prev_value_(return_prev_value) {
470-
}
471-
472-
void CachePrevValueIfNeeded(const PrimeValue& pv) {
473-
if (return_prev_value_) {
474-
// We call lazily call GetString() here to save string copying when not needed.
475-
prev_value_ = GetString(pv);
476-
}
477-
}
478-
479-
// Returns either the previous value or `status`, depending on return_prev_value_.
480-
OpResult<optional<string>> Return(OpStatus status) && {
481-
if (return_prev_value_) {
482-
return std::move(prev_value_);
483-
} else {
484-
return status;
485-
}
486-
}
487-
488-
private:
489-
bool return_prev_value_;
490-
std::optional<string> prev_value_;
491-
};
492-
493457
SinkReplyBuilder::MGetResponse OpMGet(bool fetch_mcflag, bool fetch_mcver, const Transaction* t,
494458
EngineShard* shard) {
495459
auto keys = t->GetShardArgs(shard->shard_id());
@@ -561,119 +525,52 @@ struct StringValue {
561525

562526
} // namespace
563527

564-
OpResult<optional<string>> SetCmd::Set(const SetParams& params, string_view key,
565-
string_view value) {
566-
bool fetch_val = params.flags & SET_GET;
567-
SetResultBuilder result_builder(fetch_val);
568-
569-
EngineShard* shard = op_args_.shard;
570-
auto& db_slice = shard->db_slice();
528+
OpStatus SetCmd::Set(const SetParams& params, string_view key, string_view value) {
529+
auto& db_slice = op_args_.shard->db_slice();
571530

572531
DCHECK(db_slice.IsDbValid(op_args_.db_cntx.db_index));
573-
574532
VLOG(2) << "Set " << key << "(" << db_slice.shard_id() << ") ";
575533

576-
// if SET_GET is not set then prev_val is null.
577-
DCHECK(fetch_val || params.prev_val == nullptr);
578-
579534
if (params.IsConditionalSet()) {
580-
// We do not always set prev_val and we use result_builder for that.
581-
bool fetch_value = params.prev_val || fetch_val;
582-
DbSlice::ItAndUpdater find_res;
583-
if (fetch_value) {
584-
find_res = db_slice.FindAndFetchMutable(op_args_.db_cntx, key);
585-
} else {
586-
find_res = db_slice.FindMutable(op_args_.db_cntx, key);
587-
}
588-
589-
if (IsValid(find_res.it)) {
590-
if (find_res.it->second.ObjType() != OBJ_STRING) {
591-
return OpStatus::WRONG_TYPE;
592-
}
593-
result_builder.CachePrevValueIfNeeded(find_res.it->second);
594-
}
535+
auto find_res = db_slice.FindMutable(op_args_.db_cntx, key);
536+
if (auto status = CachePrevIfNeeded(params, find_res.it); status != OpStatus::OK)
537+
return status;
595538

596-
// Make sure that we have this key, and only add it if it does exists
597539
if (params.flags & SET_IF_EXISTS) {
598540
if (IsValid(find_res.it)) {
599-
return std::move(result_builder)
600-
.Return(SetExisting(params, find_res.it, find_res.exp_it, key, value));
541+
return SetExisting(params, find_res.it, find_res.exp_it, key, value);
601542
} else {
602-
return std::move(result_builder).Return(OpStatus::SKIPPED);
543+
return OpStatus::SKIPPED;
603544
}
604545
} else {
605-
if (IsValid(find_res.it)) { // if the policy is not to overide and have the key, just return
606-
return std::move(result_builder).Return(OpStatus::SKIPPED);
607-
}
546+
DCHECK(params.flags & SET_IF_NOTEXIST) << params.flags;
547+
if (IsValid(find_res.it)) {
548+
return OpStatus::SKIPPED;
549+
} // else AddNew() is called below
608550
}
609551
}
610552

611-
// At this point we either need to add missing entry, or we
612-
// will override an existing one
613-
// Trying to add a new entry.
614553
auto op_res = db_slice.AddOrFind(op_args_.db_cntx, key);
615554
RETURN_ON_BAD_STATUS(op_res);
616-
auto& add_res = *op_res;
617-
618-
auto it = add_res.it;
619-
if (!add_res.is_new) {
620-
if (fetch_val && it->second.ObjType() != OBJ_STRING) {
621-
return OpStatus::WRONG_TYPE;
622-
}
623-
result_builder.CachePrevValueIfNeeded(it->second);
624-
return std::move(result_builder).Return(SetExisting(params, it, add_res.exp_it, key, value));
625-
}
626-
627-
// Adding new value.
628-
PrimeValue tvalue{value};
629-
tvalue.SetFlag(params.memcache_flags != 0);
630-
it->second = std::move(tvalue);
631-
632-
if (params.expire_after_ms) {
633-
db_slice.AddExpire(op_args_.db_cntx.db_index, it,
634-
params.expire_after_ms + op_args_.db_cntx.time_now_ms);
635-
}
636-
637-
if (params.memcache_flags)
638-
db_slice.SetMCFlag(op_args_.db_cntx.db_index, it->first.AsRef(), params.memcache_flags);
639-
640-
if (params.flags & SET_STICK) {
641-
it->first.SetSticky(true);
642-
}
643-
644-
if (shard->tiered_storage() &&
645-
TieredStorage::EligibleForOffload(value.size())) { // external storage enabled.
646-
shard->tiered_storage()->ScheduleOffloadWithThrottle(op_args_.db_cntx.db_index, it.GetInnerIt(),
647-
key);
648-
}
649555

650-
if (shard->tiered_storage_v2()) { // external storage enabled
651-
shard->tiered_storage_v2()->Stash(key, &it->second);
652-
}
556+
if (!op_res->is_new) {
557+
if (auto status = CachePrevIfNeeded(params, op_res->it); status != OpStatus::OK)
558+
return status;
653559

654-
if (manual_journal_ && op_args_.shard->journal()) {
655-
RecordJournal(params, key, value);
560+
return SetExisting(params, op_res->it, op_res->exp_it, key, value);
561+
} else {
562+
AddNew(params, op_res->it, op_res->exp_it, key, value);
563+
return OpStatus::OK;
656564
}
657-
658-
return std::move(result_builder).Return(OpStatus::OK);
659565
}
660566

661567
OpStatus SetCmd::SetExisting(const SetParams& params, DbSlice::Iterator it,
662568
DbSlice::ExpIterator e_it, string_view key, string_view value) {
663-
if (params.flags & SET_IF_NOTEXIST)
664-
return OpStatus::SKIPPED;
569+
DCHECK_EQ(params.flags & SET_IF_NOTEXIST, 0);
665570

666571
PrimeValue& prime_value = it->second;
667572
EngineShard* shard = op_args_.shard;
668573

669-
if (params.prev_val) {
670-
if (prime_value.ObjType() != OBJ_STRING)
671-
return OpStatus::WRONG_TYPE;
672-
673-
string val = GetString(prime_value);
674-
params.prev_val->emplace(std::move(val));
675-
}
676-
677574
DbSlice& db_slice = shard->db_slice();
678575
uint64_t at_ms =
679576
params.expire_after_ms ? params.expire_after_ms + op_args_.db_cntx.time_now_ms : 0;
@@ -700,7 +597,6 @@ OpStatus SetCmd::SetExisting(const SetParams& params, DbSlice::Iterator it,
700597
prime_value.SetFlag(params.memcache_flags != 0);
701598
db_slice.SetMCFlag(op_args_.db_cntx.db_index, it->first.AsRef(), params.memcache_flags);
702599

703-
db_slice.RemoveFromTiered(it, op_args_.db_cntx.db_index);
704600
// overwrite existing entry.
705601
prime_value.SetString(value);
706602
DCHECK(!prime_value.HasIoPending());
@@ -712,6 +608,43 @@ OpStatus SetCmd::SetExisting(const SetParams& params, DbSlice::Iterator it,
712608
return OpStatus::OK;
713609
}
714610

611+
void SetCmd::AddNew(const SetParams& params, DbSlice::Iterator it, DbSlice::ExpIterator e_it,
612+
std::string_view key, std::string_view value) {
613+
EngineShard* shard = op_args_.shard;
614+
auto& db_slice = shard->db_slice();
615+
616+
// Adding new value.
617+
PrimeValue tvalue{value};
618+
tvalue.SetFlag(params.memcache_flags != 0);
619+
it->second = std::move(tvalue);
620+
621+
if (params.expire_after_ms) {
622+
db_slice.AddExpire(op_args_.db_cntx.db_index, it,
623+
params.expire_after_ms + op_args_.db_cntx.time_now_ms);
624+
}
625+
626+
if (params.memcache_flags)
627+
db_slice.SetMCFlag(op_args_.db_cntx.db_index, it->first.AsRef(), params.memcache_flags);
628+
629+
if (params.flags & SET_STICK) {
630+
it->first.SetSticky(true);
631+
}
632+
633+
if (shard->tiered_storage() &&
634+
TieredStorage::EligibleForOffload(value.size())) { // external storage enabled.
635+
shard->tiered_storage()->ScheduleOffloadWithThrottle(op_args_.db_cntx.db_index, it.GetInnerIt(),
636+
key);
637+
}
638+
639+
if (shard->tiered_storage_v2()) { // external storage enabled
640+
shard->tiered_storage_v2()->Stash(key, &it->second);
641+
}
642+
643+
if (manual_journal_ && op_args_.shard->journal()) {
644+
RecordJournal(params, key, value);
645+
}
646+
}
647+
715648
void SetCmd::RecordJournal(const SetParams& params, string_view key, string_view value) {
716649
absl::InlinedVector<string_view, 5> cmds({key, value}); // 5 is theoretical maximum;
717650

@@ -737,6 +670,27 @@ void SetCmd::RecordJournal(const SetParams& params, string_view key, string_view
737670
dfly::RecordJournal(op_args_, "SET", ArgSlice{cmds});
738671
}
739672

673+
OpStatus SetCmd::CachePrevIfNeeded(const SetCmd::SetParams& params, DbSlice::Iterator it) {
674+
if (!params.prev_val || !IsValid(it))
675+
return OpStatus::OK;
676+
if (it->second.ObjType() != OBJ_STRING)
677+
return OpStatus::WRONG_TYPE;
678+
679+
*params.prev_val = GetString(it->second);
680+
return OpStatus::OK;
681+
}
682+
683+
// Wrapper to call SetCmd::Set in ScheduleSingleHop
684+
OpStatus SetGeneric(ConnectionContext* cntx, const SetCmd::SetParams& sparams, string_view key,
685+
string_view value) {
686+
DCHECK(cntx->transaction);
687+
688+
bool manual_journal = cntx->cid->opt_mask() & CO::NO_AUTOJOURNAL;
689+
return cntx->transaction->ScheduleSingleHop([&](Transaction* t, EngineShard* shard) {
690+
return SetCmd(t->GetOpArgs(shard), manual_journal).Set(sparams, key, value);
691+
});
692+
}
693+
740694
void StringFamily::Set(CmdArgList args, ConnectionContext* cntx) {
741695
facade::CmdArgParser parser{args};
742696

@@ -812,7 +766,11 @@ void StringFamily::Set(CmdArgList args, ConnectionContext* cntx) {
812766
return builder->SendError(kSyntaxErr);
813767
}
814768

815-
OpResult result{SetGeneric(cntx, sparams, key, value, true)};
769+
optional<string> prev;
770+
if (sparams.flags & SetCmd::SET_GET)
771+
sparams.prev_val = &prev;
772+
773+
OpStatus result = SetGeneric(cntx, sparams, key, value);
816774

817775
if (result == OpStatus::WRONG_TYPE) {
818776
return cntx->SendError(kWrongTypeErr);
@@ -821,8 +779,8 @@ void StringFamily::Set(CmdArgList args, ConnectionContext* cntx) {
821779
if (sparams.flags & SetCmd::SET_GET) {
822780
auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
823781
// When SET_GET is used, the reply is not affected by whether anything was set.
824-
if (result->has_value()) {
825-
rb->SendBulkString(result->value());
782+
if (prev.has_value()) {
783+
rb->SendBulkString(*prev);
826784
} else {
827785
rb->SendNull();
828786
}
@@ -861,7 +819,8 @@ void StringFamily::SetNx(CmdArgList args, ConnectionContext* cntx) {
861819
SetCmd::SetParams sparams;
862820
sparams.flags |= SetCmd::SET_IF_NOTEXIST;
863821
sparams.memcache_flags = cntx->conn_state.memcache_flag;
864-
const auto results{SetGeneric(cntx, std::move(sparams), key, value, false)};
822+
const auto results{SetGeneric(cntx, sparams, key, value)};
823+
865824
SinkReplyBuilder* builder = cntx->reply_builder();
866825
if (results == OpStatus::OK) {
867826
return builder->SendLong(1); // this means that we successfully set the value
@@ -939,13 +898,7 @@ void StringFamily::GetSet(CmdArgList args, ConnectionContext* cntx) {
939898

940899
SetCmd::SetParams sparams;
941900
sparams.prev_val = &prev_val;
942-
943-
auto cb = [&](Transaction* t, EngineShard* shard) {
944-
SetCmd cmd(t->GetOpArgs(shard), false);
945-
946-
return cmd.Set(sparams, key, value).status();
947-
};
948-
OpStatus status = cntx->transaction->ScheduleSingleHop(std::move(cb));
901+
OpStatus status = SetGeneric(cntx, sparams, key, value);
949902

950903
if (status != OpStatus::OK) {
951904
cntx->SendError(status);
@@ -1201,14 +1154,7 @@ void StringFamily::SetExGeneric(bool seconds, CmdArgList args, ConnectionContext
12011154
sparams.expire_after_ms = unit_vals;
12021155
}
12031156

1204-
auto cb = [&](Transaction* t, EngineShard* shard) {
1205-
SetCmd sg(t->GetOpArgs(shard), true);
1206-
return sg.Set(sparams, key, value).status();
1207-
};
1208-
1209-
OpResult<void> result = cntx->transaction->ScheduleSingleHop(std::move(cb));
1210-
1211-
return cntx->SendError(result.status());
1157+
cntx->SendError(SetGeneric(cntx, sparams, key, value));
12121158
}
12131159

12141160
void StringFamily::MGet(CmdArgList args, ConnectionContext* cntx) {

0 commit comments

Comments
 (0)