diff --git a/Project.toml b/Project.toml index 023d491..f3f5b5f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TextClassification" uuid = "8e067cb0-742a-4f90-93f9-f1fa01b385ec" authors = ["Eric S. Tellez "] -version = "0.4.2" +version = "0.4.3" [deps] CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" diff --git a/src/microtc.jl b/src/microtc.jl index f16136f..0100106 100644 --- a/src/microtc.jl +++ b/src/microtc.jl @@ -10,20 +10,22 @@ export filtered_power_set, predict, predict_corpus, vectorize, vectorize_corpus, import Base: hash, isequal using SparseArrays -struct MicroTC{C_<:MicroTC_Config, CLS_<:Any, TextModel_<:TextModel} +struct MicroTC{C_<:MicroTC_Config, CLS_<:Any, TextModel_<:TextModel, LabelType<:Any} config::C_ cls::CLS_ textmodel::TextModel_ tok::Tokenizer + levels::Vector{LabelType} end StructTypes.StructType(::Type{<:MicroTC}) = StructTypes.Struct() function Base.show(io::IO, model::MicroTC) - print(io, "{MicroTC ") - show(io, model.config) - show(io, model.cls) - show(io, model.textmodel) - show(io, model.tok) + print(io, "{MicroTC") + show(io, ' ', model.config) + show(io, ' ', model.cls) + show(io, ' ', model.textmodel) + show(io, ' ', model.tok) + show(io, ' ', model.levels) print(io, "}") end @@ -86,7 +88,7 @@ function MicroTC( tok=Tokenizer(config.textconfig, invmap=nothing), verbose=true) where {S<:SVEC} cls = create(config.cls, train_X, train_y, textmodel.m) - MicroTC(config, cls, textmodel, tok) + MicroTC(config, cls, textmodel, tok, copy(levels(train_y))) end """ @@ -142,13 +144,13 @@ end Predicts the label of the given input """ predict(tc::MicroTC, text) = predict(tc.cls, vectorize(tc, text)) -predict(tc::MicroTC, vec::SVEC) = predict(tc.cls, vec) +predict(tc::MicroTC, vec::SVEC) = tc.levels[predict(tc.cls, vec)] function predict_corpus(tc::MicroTC, corpus; bow=BOW(), tok=tc.tok, normalize=true) - V = Vector{UInt32}(undef, length(corpus)) + V = Vector{eltype(tc.levels)}(undef, length(corpus)) for i in eachindex(corpus) empty!(bow) diff --git a/test/runtests.jl b/test/runtests.jl index 90e3586..6cbdd96 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -37,7 +37,7 @@ end #valX = vectorize_corpus(tc, _testcorpus) #ypred = predict.(tc, valX) ypred = predict_corpus(tc, _testcorpus) - push!(S, recall_score(_testlabels.refs, ypred, weight=:macro)) + push!(S, recall_score(_testlabels, ypred, weight=:macro)) end 1.0 - mean(S) @@ -75,12 +75,12 @@ end end cls = MicroTC(best_list[1][1], traincorpus, trainlabels) - sc = classification_scores(testlabels.refs, predict_corpus(cls, testcorpus)) + sc = classification_scores(testlabels, predict_corpus(cls, testcorpus)) @info "*** Performance on test: " sc @test sc.accuracy > 0.6 cls_ = JSON3.read(JSON3.write(cls), typeof(cls)) - sc = classification_scores(testlabels.refs, predict_corpus(cls, testcorpus)) + sc = classification_scores(testlabels, predict_corpus(cls, testcorpus)) @info "*** Performance on test: " sc @test sc.accuracy > 0.6