Skip to content

Commit

Permalink
FST works (#484)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuzhichang authored Jan 27, 2024
1 parent 6b1ba2f commit 9cea3eb
Show file tree
Hide file tree
Showing 8 changed files with 203 additions and 134 deletions.
15 changes: 8 additions & 7 deletions src/storage/invertedindex/fst/build.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,13 @@ void FstBuilder::InsertOutput(u8 *bs_ptr, SizeT bs_len, u64 val) {
void FstBuilder::CompileFrom(SizeT istate) {
SizeT addr = NONE_ADDRESS;
while (istate + 1 < unfinished_.Len()) {
UniquePtr<BuilderNode> node;
if (addr == NONE_ADDRESS)
node = unfinished_.PopEmpty();
else
node = unfinished_.PopFreeze(addr);
addr = Compile(*node);
if (addr == NONE_ADDRESS) {
UniquePtr<BuilderNode> node = unfinished_.PopEmpty();
addr = Compile(*node);
} else {
UniquePtr<BuilderNode> node = unfinished_.PopFreeze(addr);
addr = Compile(*node);
}
assert(addr != NONE_ADDRESS);
}
unfinished_.TopLastFreeze(addr);
Expand All @@ -69,7 +70,7 @@ CompiledAddr FstBuilder::Compile(BuilderNode &node) {
CompiledAddr start_addr = wtr_.Count();
Node::Compile(wtr_, last_addr_, start_addr, node);
last_addr_ = wtr_.Count() - 1;
registry_.Insert(node, start_addr);
registry_.Insert(node, last_addr_);
return last_addr_;
}

Expand Down
27 changes: 17 additions & 10 deletions src/storage/invertedindex/fst/build.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ struct BuilderNodeUnfinished {
Optional<LastTransition> last_;

void LastCompiled(CompiledAddr addr) {
assert(last_.has_value());
node_->trans_.push_back(Transition{last_->inp_, last_->out_, addr});
if (last_.has_value()) {
node_->trans_.push_back(Transition{last_->inp_, last_->out_, addr});
last_.reset();
}
}

void AddOutputPrefix(Output prefix) {
Expand All @@ -41,6 +43,11 @@ struct BuilderNodeUnfinished {
struct UnfinishedNodes {
Vector<UniquePtr<BuilderNodeUnfinished>> stack_;

UnfinishedNodes() {
stack_.reserve(64);
PushEmpty(false);
}

SizeT Len() { return stack_.size(); }

void PushEmpty(bool is_final) {
Expand Down Expand Up @@ -90,8 +97,7 @@ struct UnfinishedNodes {
assert(!unfinished->last_.has_value());
unfinished->last_ = LastTransition{bs_ptr[0], out};
for (SizeT i = 1; i < bs_len; i++) {
auto node = MakeUnique<BuilderNode>();
auto unfinished = MakeUnique<BuilderNodeUnfinished>(std::move(node), LastTransition{bs_ptr[i], Output::Zero()});
auto unfinished = MakeUnique<BuilderNodeUnfinished>(MakeUnique<BuilderNode>(), LastTransition{bs_ptr[i], Output::Zero()});
stack_.push_back(std::move(unfinished));
}
PushEmpty(true);
Expand All @@ -110,19 +116,20 @@ struct UnfinishedNodes {
}

SizeT FindCommonPrefixAndSetOutput(u8 *bs_ptr, SizeT bs_len, Output &out) {
assert(stack_.size() >= 1);
SizeT i = 0;
for (; i < bs_len; i++) {
if (i >= stack_.size())
break;
SizeT common_len = std::min(bs_len, stack_.size() - 1);
for (; i < common_len; i++) {
auto &t = stack_[i]->last_;
if (!t.has_value() || t->inp_ != bs_ptr[i])
assert(t.has_value());
if (t->inp_ != bs_ptr[i])
break;
Output common_pre = t->out_.Prefix(out);
Output add_prefix = t->out_.Sub(common_pre);
out = out.Sub(common_pre);
t->out_ = common_pre;
if (!add_prefix.IsZero()) {
stack_[i]->AddOutputPrefix(add_prefix);
t->out_ = common_pre;
stack_[i + 1]->AddOutputPrefix(add_prefix);
}
}
return i;
Expand Down
93 changes: 38 additions & 55 deletions src/storage/invertedindex/fst/fst.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -146,14 +146,6 @@ public:
/// Returns the address of the root node of this fst.
CompiledAddr RootAddr() { return meta_.root_addr_; }

Optional<Output> EmptyFinalOutput() {
auto root = Root();
if (root->IsFinal()) {
return root->FinalOutput();
}
return {};
}

/// Retrieves the value associated with a key.
///
/// If the key does not exist, then `None` is returned.
Expand All @@ -164,12 +156,12 @@ public:

private:
/// Returns the root node of this fst.
UniquePtr<Node> Root() { return UniquePtr<Node>(new Node(meta_.version_, meta_.root_addr_, data_ptr_, data_len_)); }
UniquePtr<Node> Root() { return MakeUnique<Node>(meta_.version_, meta_.root_addr_, data_ptr_); }

/// Returns the node at the given address.
///
/// Node addresses can be obtained by reading transitions on `Node` values.
UniquePtr<Node> NodeAt(CompiledAddr addr) { return MakeUnique<Node>(meta_.version_, addr, data_ptr_, data_len_); }
UniquePtr<Node> NodeAt(CompiledAddr addr) { return MakeUnique<Node>(meta_.version_, addr, data_ptr_); }
};

export enum BoundType {
Expand Down Expand Up @@ -223,6 +215,42 @@ private:
public:
Stream(Fst &fst, Bound min = Bound(), Bound max = Bound()) : fst_(fst), end_at_(max) { SeekMin(min); }

/// @brief Get next key-value pair per lexicographical order
///
/// Walks the fst with the explicit `stack_` of stream states: the top state's
/// next transition is followed, its input byte is appended to `inp_`, and the
/// target node is pushed. A state whose transitions are all consumed is popped
/// (backtracking). Each final node reached yields one pair.
///
/// @param key Stores the key of the pair when found (a copy of `inp_`); not written when false is returned
/// @param val Stores the value of the pair when found (path output combined with the node's final output); not written when false is returned
/// @return true if found next pair, false if not (stack exhausted, or the upper bound `end_at_` was exceeded)
bool Next(Vector<u8> &key, u64 &val) {
while (!stack_.empty()) {
StreamState &state = stack_.back();
if (state.trans_ >= state.node_->Len()) {
// All transitions of the top node have been visited: backtrack.
// `state` refers to stack_.back(), so it is inspected BEFORE the pop_back() below.
if (state.node_->Addr() != fst_.RootAddr()) {
// Drop the input byte that led into this node. The root is skipped --
// presumably because its stack entry was pushed without a matching byte in `inp_` (TODO confirm against SeekMin).
inp_.pop_back();
}
stack_.pop_back();
continue;
}
Transition trans = state.node_->TransAt(state.trans_);
// Output accumulated from the start of the path through this transition.
Output out = state.out_.Cat(trans.out_);
UniquePtr<Node> next_node = fst_.NodeAt(trans.addr_);
inp_.push_back(trans.inp_);
if (end_at_.ExceededBy(inp_.data(), inp_.size())) {
// We are done, forever.
// Clearing the stack makes every later call fall through the loop and return false.
stack_.clear();
return false;
}
bool is_final = next_node->IsFinal();
if (is_final) {
// `inp_` currently spells a complete key; report it with its total output.
key = inp_;
val = out.Cat(next_node->FinalOutput()).Value();
}
// Advance the parent's cursor before pushing: emplace_back() may grow `stack_`,
// after which `state` must no longer be used (assuming `stack_` is a Vector -- its declaration is not visible here).
state.trans_++;
stack_.emplace_back(std::move(next_node), 0, out);
if (is_final)
return true;
}
return false;
}

private:
/// Seeks the underlying stream such that the next key to be read is the
/// smallest key in the underlying fst that satisfies the given minimum
Expand All @@ -233,9 +261,6 @@ private:
/// states.
void SeekMin(Bound min) {
if (min.IsEmpty()) {
if (min.IsInclusive()) {
empty_output_ = fst_.EmptyFinalOutput();
}
stack_.emplace_back(fst_.Root(), 0, Output());
return;
}
Expand Down Expand Up @@ -286,48 +311,6 @@ private:
}
}
}

/// @brief Get next key-value pair per lexicographical order
/// @param key Stores the key of the pair when found
/// @param val Stores the value of the pair when found
/// @return true if found next pair, false if not
bool Next(Vector<u8> &key, Output &val) {
if (empty_output_.has_value()) {
if (end_at_.ExceededBy(nullptr, 0)) {
stack_.clear();
return false;
}
}
while (!stack_.empty()) {
StreamState &state = stack_.back();
if (state.trans_ >= state.node_->Len()) {
stack_.pop_back();
if (state.node_->Addr() != fst_.RootAddr()) {
inp_.pop_back();
}
continue;
}
Transition trans = state.node_->TransAt(state.trans_);
Output out = state.out_.Cat(trans.out_);
UniquePtr<Node> next_node = fst_.NodeAt(trans.addr_);
inp_.push_back(trans.inp_);
if (end_at_.ExceededBy(inp_.data(), inp_.size())) {
// We are done, forever.
stack_.clear();
return false;
}
bool is_final = next_node->IsFinal();
if (is_final) {
key = inp_;
val = out.Cat(next_node->FinalOutput());
}
state.trans_++;
stack_.emplace_back(std::move(next_node), 0, out);
if (is_final)
return true;
}
return false;
}
};

} // namespace infinity
40 changes: 22 additions & 18 deletions src/storage/invertedindex/fst/node.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ struct Node {

/// Creates a new node at the address given.
/// `data` should be a slice to an entire FST.
Node(u64 version, CompiledAddr addr, u8 *data_ptr, SizeT data_len);
Node(u64 version, CompiledAddr addr, u8 *data_ptr);

static void Compile(Writer &wtr, CompiledAddr last_addr, CompiledAddr addr, BuilderNode &node);

Expand All @@ -227,7 +227,7 @@ struct Node {

/// If this node is final and has a terminal output value, then it is
/// returned. Otherwise, a zero output is returned.
Output FinalOutput() { return final_output_; }
Output FinalOutput() { return is_final_ ? final_output_ : Output::Zero(); }

/// Returns true if and only if this node corresponds to a final or "match"
/// state in the finite state transducer.
Expand Down Expand Up @@ -448,7 +448,7 @@ public:
}

Output OutputOf(const Node &node) const {
SizeT osize = node.sizes_.TransitionPackSize();
SizeT osize = node.sizes_.OutputPackSize();
if (osize == 0)
return Output::Zero();
SizeT tsize = node.sizes_.TransitionPackSize();
Expand Down Expand Up @@ -519,10 +519,14 @@ public:
if (node.is_final_) {
PackUintIn(wtr, node.final_output_.Value(), osize);
}
for (int i = int(ntrans) - 1; i >= 0; i--) {
const Transition &t = node.trans_[i];
PackUintIn(wtr, t.out_.Value(), osize);
}
}
for (int i = int(ntrans) - 1; i >= 0; i--) {
const Transition &t = node.trans_[i];
PackUintIn(wtr, t.out_.Value(), osize);
State::PackDeltaIn(wtr, addr, t.addr_, tsize);
}
for (int i = int(ntrans) - 1; i >= 0; i--) {
const Transition &t = node.trans_[i];
Expand Down Expand Up @@ -684,10 +688,9 @@ public:
if (osize == 0) {
return Output::Zero();
}
SizeT at = node.start_ - NtransLen() - 1 // pack size
- TotalTransSize(node.version_, node.sizes_, node.ntrans_) // outputs
- (i * osize) // the previous output values
- osize; // the desired output value
SizeT at = node.start_ - NtransLen() - 1 // pack size
- TotalTransSize(node.version_, node.sizes_, node.ntrans_) - (i * osize) // the previous output values
- osize; // the desired output value
return Output(UnpackUint(node.data_ptr_ + at, osize));
}

Expand Down Expand Up @@ -724,8 +727,8 @@ UniquePtr<State> State::New(u8 *data_ptr, SizeT addr) {
}
}

Node::Node(u64 version, CompiledAddr addr, u8 *data_ptr, SizeT data_len)
: version_(version), data_ptr_(data_ptr), data_len_(data_len), final_output_(Output::Zero()) {
Node::Node(u64 version, CompiledAddr addr, u8 *data_ptr)
: version_(version), data_ptr_(data_ptr), data_len_(addr + 1), final_output_(Output::Zero()) {
if (addr == EMPTY_ADDRESS) {
data_len_ = 0;
state_ = MakeUnique<StateEmptyFinal>();
Expand All @@ -736,34 +739,35 @@ Node::Node(u64 version, CompiledAddr addr, u8 *data_ptr, SizeT data_len)
sizes_ = PackSizes();
return;
}
data_len_ = addr + 1;
start_ = addr;
u8 val = data_ptr[addr];
switch ((val & 0b11000000) >> 6) {
case 0b11: {
auto state = MakeUnique<StateOneTransNext>(val);
end_ = state->EndAddr(data_len);
end_ = state->EndAddr(data_len_);
is_final_ = false;
ntrans_ = 1;
sizes_ = PackSizes();
state_ = std::move(state);
break;
}
case 0b10: {
auto state = MakeUnique<StateOneTrans>(val);
sizes_ = state->Sizes(data_ptr, data_len);
end_ = state->EndAddr(data_len, sizes_);
sizes_ = state->Sizes(data_ptr, data_len_);
end_ = state->EndAddr(data_len_, sizes_);
is_final_ = false;
ntrans_ = 1;
state_ = std::move(state);
break;
}
default: {
auto state = MakeUnique<StateAnyTrans>(val);
SizeT ntrans = state->Ntrans(data_ptr, data_len);
sizes_ = state->Sizes(data_ptr, data_len);
end_ = state->EndAddr(version, data_len, sizes_, ntrans);
SizeT ntrans = state->Ntrans(data_ptr, data_len_);
sizes_ = state->Sizes(data_ptr, data_len_);
end_ = state->EndAddr(version, data_len_, sizes_, ntrans);
is_final_ = state->IsFinalState();
ntrans_ = ntrans;
final_output_ = state->FinalOutput(version, data_ptr, data_len, sizes_, ntrans);
final_output_ = state->FinalOutput(version, data_ptr, data_len_, sizes_, ntrans);
state_ = std::move(state);
}
}
Expand Down
24 changes: 5 additions & 19 deletions src/storage/invertedindex/fst/writer.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import stl;
import crc;
export module fst:writer;

namespace infinity {
export namespace infinity {

class Writer {
public:
Expand All @@ -14,29 +14,15 @@ public:
};

class BufferWriter : public Writer {
private:
u8 *buffer_;
SizeT bufferSize_;
SizeT currentPos_;
public:
Vector<u8> &buffer_;

public:
BufferWriter(u8 *buf, SizeT size) : buffer_(buf), bufferSize_(size), currentPos_(0) {}
BufferWriter(Vector<u8> &buffer) : buffer_(buffer) {}

void Write(const u8 *buf, SizeT size) override {
SizeT remainingSpace = bufferSize_ - currentPos_;
assert(size <= remainingSpace);
SizeT bytesToCopy = std::min(remainingSpace, size);
std::memcpy(buffer_ + currentPos_, buf, bytesToCopy);
currentPos_ += bytesToCopy;
}
void Write(const u8 *data_ptr, SizeT data_size) override { buffer_.insert(buffer_.end(), data_ptr, data_ptr + data_size); }

void Flush() override {}

// Additional method to get the content of the buffer
const u8 *GetBuffer() const { return buffer_; }

// Additional method to get the current position in the buffer
SizeT GetCurrentPosition() const { return currentPos_; }
};

class OstreamWriter : public Writer {
Expand Down
Loading

0 comments on commit 9cea3eb

Please sign in to comment.