Skip to content

Commit

Permalink
adds labels map into MicroTC struct; increases version
Browse files Browse the repository at this point in the history
  • Loading branch information
sadit committed Apr 27, 2021
1 parent 93d6e54 commit 76f3404
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TextClassification"
uuid = "8e067cb0-742a-4f90-93f9-f1fa01b385ec"
authors = ["Eric S. Tellez <donsadit@gmail.com>"]
version = "0.4.2"
version = "0.4.3"

[deps]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
Expand Down
20 changes: 11 additions & 9 deletions src/microtc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,22 @@ export filtered_power_set, predict, predict_corpus, vectorize, vectorize_corpus,
import Base: hash, isequal
using SparseArrays

# MicroTC bundles a configuration, a classifier, a text model, a tokenizer and the
# label levels into a single model object.
# NOTE(review): this span is a rendered diff. The first `struct` line is the header
# removed by the commit; the second one (with the extra `LabelType` parameter) and
# the `levels` field are the lines it adds.
struct MicroTC{C_<:MicroTC_Config, CLS_<:Any, TextModel_<:TextModel}
struct MicroTC{C_<:MicroTC_Config, CLS_<:Any, TextModel_<:TextModel, LabelType<:Any}
config::C_
# Classifier; the constructor shown further down builds it with `create(config.cls, ...)`.
cls::CLS_
textmodel::TextModel_
tok::Tokenizer
# Label values, filled by the constructor with `copy(levels(train_y))`; `predict` on a
# sparse vector indexes this vector with the classifier's output.
levels::Vector{LabelType}
end

StructTypes.StructType(::Type{<:MicroTC}) = StructTypes.Struct()
# Writes a brace-delimited, single-line representation of a MicroTC model to `io`.
# NOTE(review): this span is a rendered diff. The first `print` and the four two-argument
# `show` calls are the removed lines; the second `print` and the five three-argument
# `show` calls are the added lines.
function Base.show(io::IO, model::MicroTC)
print(io, "{MicroTC ")
show(io, model.config)
show(io, model.cls)
show(io, model.textmodel)
show(io, model.tok)
print(io, "{MicroTC")
# NOTE(review): the added calls pass the Char ' ' as the second argument of `show`.
# Base's three-argument `show` takes a MIME (or MIME string) in that position, so these
# calls presumably raise a MethodError unless a `show(io, ::Char, x)` method is defined
# elsewhere — confirm. `print(io, ' ')` followed by `show(io, x)` may be what was intended.
show(io, ' ', model.config)
show(io, ' ', model.cls)
show(io, ' ', model.textmodel)
show(io, ' ', model.tok)
show(io, ' ', model.levels)
print(io, "}")
end

Expand Down Expand Up @@ -86,7 +88,7 @@ function MicroTC(
tok=Tokenizer(config.textconfig, invmap=nothing),
verbose=true) where {S<:SVEC}
cls = create(config.cls, train_X, train_y, textmodel.m)
MicroTC(config, cls, textmodel, tok)
MicroTC(config, cls, textmodel, tok, copy(levels(train_y)))
end

"""
Expand Down Expand Up @@ -142,13 +144,13 @@ end
Predicts the label of the given input
"""
predict(tc::MicroTC, text) = predict(tc.cls, vectorize(tc, text))
# NOTE(review): rendered diff — the next line is the removed `SVEC` method, the one after
# it is the added replacement, which maps the classifier output through `tc.levels`.
# The text method above calls `predict(tc.cls, ...)` directly, so it returns the raw
# classifier output rather than an element of `tc.levels`; confirm whether that
# difference between the two methods is intended.
predict(tc::MicroTC, vec::SVEC) = predict(tc.cls, vec)
predict(tc::MicroTC, vec::SVEC) = tc.levels[predict(tc.cls, vec)]

function predict_corpus(tc::MicroTC, corpus;
bow=BOW(),
tok=tc.tok,
normalize=true)
V = Vector{UInt32}(undef, length(corpus))
V = Vector{eltype(tc.levels)}(undef, length(corpus))

for i in eachindex(corpus)
empty!(bow)
Expand Down
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ end
#valX = vectorize_corpus(tc, _testcorpus)
#ypred = predict.(tc, valX)
ypred = predict_corpus(tc, _testcorpus)
push!(S, recall_score(_testlabels.refs, ypred, weight=:macro))
push!(S, recall_score(_testlabels, ypred, weight=:macro))
end

1.0 - mean(S)
Expand Down Expand Up @@ -75,12 +75,12 @@ end
end

# Trains a model with the first configuration in `best_list` and checks its accuracy
# on the test split.
# NOTE(review): rendered diff — in each pair of `sc = ...` lines the first (using
# `testlabels.refs`) is the removed line and the second (using `testlabels`) is the added one.
cls = MicroTC(best_list[1][1], traincorpus, trainlabels)
sc = classification_scores(testlabels.refs, predict_corpus(cls, testcorpus))
sc = classification_scores(testlabels, predict_corpus(cls, testcorpus))
@info "*** Performance on test: " sc
@test sc.accuracy > 0.6

# Round-trips the model through JSON3.
# NOTE(review): in the lines visible here the deserialized `cls_` is never used; the
# evaluation below calls `predict_corpus(cls, ...)` again, so the round-tripped model
# is not exercised by this check — confirm whether `cls_` was intended.
cls_ = JSON3.read(JSON3.write(cls), typeof(cls))
sc = classification_scores(testlabels.refs, predict_corpus(cls, testcorpus))
sc = classification_scores(testlabels, predict_corpus(cls, testcorpus))
@info "*** Performance on test: " sc
@test sc.accuracy > 0.6

Expand Down

2 comments on commit 76f3404

@sadit
Copy link
Owner Author

@sadit sadit commented on 76f3404 Apr 27, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/35520

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.3 -m "<description of version>" 76f3404c2bc649697420166c9b4320d662b1470e
git push origin v0.4.3

Please sign in to comment.