Scripts and fixes for LLM-JP corpus v2 #34

Merged · 12 commits · Dec 23, 2023
@@ -0,0 +1,42 @@
package com.worksap.nlp.uzushio.lib.filters

import com.worksap.nlp.uzushio.lib.cleaning.{Document, Paragraph}
import com.worksap.nlp.uzushio.lib.filters.base.DocFilter

import scala.collection.mutable.ArrayBuffer

/** This class is a hack put in place before the final bugfix. */
class AdjacentDuplicateParagraphs extends DocFilter {

  /** Collapses each run of consecutive paragraphs with identical text into a
    * single paragraph; non-adjacent duplicates are left alone. */
  private def compressParagraphs(paragraphs: Seq[Paragraph]): Seq[Paragraph] = {
    val result = new ArrayBuffer[Paragraph]()
    val iter = paragraphs.iterator
    if (!iter.hasNext) {
      return paragraphs
    }

    var prev = iter.next()
    while (iter.hasNext) {
      val next = iter.next()
      if (next.text != prev.text) {
        result += prev
        prev = next
      }
    }

    // the survivor of the final run still has to be emitted
    result += prev
    result
  }

  override def checkDocument(doc: Document): Document = {
    val newPars = compressParagraphs(doc.paragraphs)
    if (newPars.length == doc.paragraphs.length) {
      doc // nothing was dropped, keep the original instance
    } else {
      doc.copy(paragraphs = newPars)
    }
  }

  override val toString = "AdjacentDuplicateParagraphs"
}
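Review note: a quick sketch of what the new filter does. `Paragraph("body", text)` follows the two-argument construction used elsewhere in this PR; building a `Document` directly from a paragraph sequence is an assumption for illustration.

```scala
// Only adjacent duplicates are collapsed; a later repeat of the same text survives.
val doc = Document(Seq(
  Paragraph("body", "foo"),
  Paragraph("body", "foo"), // dropped: identical to its predecessor
  Paragraph("body", "bar"),
  Paragraph("body", "foo")  // kept: not adjacent to the first run
))
val filtered = new AdjacentDuplicateParagraphs().checkDocument(doc)
// filtered.paragraphs.map(_.text) == Seq("foo", "bar", "foo")
```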
@@ -3,9 +3,11 @@ package com.worksap.nlp.uzushio.lib.filters
import com.github.jbaiter.kenlm.BufferEvaluator
import com.worksap.nlp.sudachi.{Dictionary, Morpheme}
import com.worksap.nlp.uzushio.lib.cleaning.{Document, Paragraph}
+import com.worksap.nlp.uzushio.lib.filters.KenLMEvaluator.logger
import com.worksap.nlp.uzushio.lib.filters.base.{DocFilter, HighLowDocFilter}
import com.worksap.nlp.uzushio.lib.resources.{KenLM, Sudachi}
-import com.worksap.nlp.uzushio.lib.utils.Paragraphs
+import com.worksap.nlp.uzushio.lib.utils.{Paragraphs, SentenceIterator}
+import org.slf4j.LoggerFactory

class KenLMDocAvgPerplexity(
  sudachi: String,
@@ -44,19 +46,26 @@ class KenLMDocAvgPerplexity(
class KenLMEvaluator(sudachi: String, kenlm: String) {
  private val dictionary: Dictionary = Sudachi.get(sudachi)
  final protected val tokenizer = dictionary.create()
-  final protected val evaluator = KenLM.get(kenlm).bufferEvaluator(64 * 1024, 1024)
+  final protected val evaluator = KenLM.get(kenlm).bufferEvaluator(128 * 1024, 1024)

  def processParagraph(p: Paragraph): BufferEvaluator = {
-    val tokens = tokenizer.tokenize(p.text)
    val ev = evaluator
-    val iter = tokens.iterator()
-    var continue = true
-    ev.clear()
-    while (iter.hasNext && continue) {
-      val token = iter.next()
-      if (acceptedToken(token)) {
-        val remaining = ev.append(token.surface())
-        continue = remaining > 0
-      }
-    }
+
+    val linesIterator = new SentenceIterator(p.text, 16 * 1024)
+
+    while (linesIterator.hasNext) {
+      val line = linesIterator.next()
+      val tokens = tokenizer.tokenize(line)
+
+      val iter = tokens.iterator()
+      var continue = true
+      ev.clear()
+      while (iter.hasNext && continue) {
+        val token = iter.next()
+        if (acceptedToken(token)) {
+          val remaining = ev.append(token.surface())
+          continue = remaining > 0
+        }
+      }
+    }
    ev
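Review note: processParagraph now feeds KenLM sentence by sentence through SentenceIterator (added later in this diff) instead of tokenizing the whole paragraph in one call, so a single pathological paragraph can no longer overrun the tokenizer or the scoring buffer (itself doubled to 128 KiB). A minimal sketch of the chunking contract this relies on:

```scala
// Every chunk ends right after a separator character and is capped at
// maxLength (16 * 1024 above), so tokenizer.tokenize(line) sees bounded input.
val chunks = new SentenceIterator("一文目。二文目!三文目", 16 * 1024)
chunks.foreach(println) // 一文目。 / 二文目! / 三文目
```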
@@ -81,12 +90,21 @@ class KenLMEvaluator(sudachi: String, kenlm: String) {
  def extractScore(ev: BufferEvaluator): Double = ev.evaluate()

  def scoreParagraph(p: Paragraph): Double = {
-    val e = processParagraph(p)
+    val e =
+      try {
+        processParagraph(p)
+      } catch {
+        case ex: Exception =>
+          logger.error(s"failed to analyze ${p.text}", ex)
+          return -50.0
+      }
    extractScore(e)
  }
}
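With the try/catch above, scoring becomes total: any tokenizer or KenLM failure is logged and mapped to a constant. Judging by the 10^(-score) conversion in LMPerplexity.process later in this diff, the score is an average log10 probability, so the fallback pins failed paragraphs to an astronomically high perplexity:

```scala
val fallbackScore = -50.0                     // average log10 probability
val perplexity = math.pow(10, -fallbackScore) // 1e50: ranks as maximally perplexing
```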

object KenLMEvaluator {
+  final private val logger = LoggerFactory.getLogger(classOf[KenLMEvaluator])
+
  def make(sudachi: String, kenlm: String, ratio: Float): KenLMEvaluator = {
    if (ratio < 1e-3) {
      new KenLMEvaluator(sudachi, kenlm)
@@ -2,6 +2,10 @@ package com.worksap.nlp.uzushio.lib.runners

import com.typesafe.config.ConfigFactory
import com.worksap.nlp.uzushio.lib.cleaning.{Document, Paragraph, Pipeline}
+import com.worksap.nlp.uzushio.lib.runners.DeduplicateParagraphs.{
+  cleanParagraphUdf,
+  splitTextToParagraphs
+}
import com.worksap.nlp.uzushio.lib.runners.DuplicateCandidateRow._
import com.worksap.nlp.uzushio.lib.stats.{NgramBitSignatures, NgramHashExtractor, SimHashProcessor}
import com.worksap.nlp.uzushio.lib.utils.Resources.AutoClosableResource
@@ -372,7 +376,7 @@ class DeduplicateParagraphs(
    import com.worksap.nlp.uzushio.lib.utils.BuilderSyntax._
    val rawData = spark.read.parquet(args.inputs: _*)

-    val basicData = prepareBasicData(rawData)
+    val basicData = prepareDataForDedup(rawData)

    val reprParagraphs =
      if (args.hasStage("reprHashes")) {
@@ -422,14 +426,10 @@
      .option("compression", args.compression).save(args.output)
  }

-  private def prepareBasicData(rawData: DataFrame): DataFrame = {
-    val cleanParagraphs = udf((x: String) => Paragraphs.extractCleanParagraphs(x))
-
-    val splitDocs = rawData.select(
-      posexplode(cleanParagraphs(rawData.col("text"))).as(Seq("pos", "text"))
-    )
-
-    prepareDataset(splitDocs)
+  private def prepareDataForDedup(rawData: DataFrame): DataFrame = {
+    val exploded = splitTextToParagraphs(rawData)
+    val noMetadata = exploded.withColumn("text", cleanParagraphUdf(exploded.col("text")))
+    prepareDataset(noMetadata)
  }

  def prepareDataset(ds: DataFrame): DataFrame = {
@@ -823,7 +823,7 @@ object DeduplicateParagraphs {

    val basicCols = (if (debug) {
      joined.columns.filter {
-        case "parHash" => false
+        case "parHash" => true
        case "exactFreq" | "nearFreq" => false
        case _ => true
      }
@@ -845,7 +845,9 @@
    )
  }

-  private def hashParagraphs(raw: DataFrame) = {
+  private val cleanParagraphUdf = udf((s: String) => Paragraphs.extractCleanParagraph(s))
+
+  private def splitTextToParagraphs(raw: DataFrame) = {
    val explodeCols = raw.columns.map {
      case "text" => posexplode(split(raw.col("text"), "\n\n")).as(Seq("pos", "text"))
      case col => raw.col(col)
@@ -855,9 +857,12 @@
      octet_length(raw.col("text")) < 2 * 1024 * 1024 && countParagraphs(raw.col("text")) < 1000
    ).select(explodeCols: _*)

-    val cleanParUdf = udf((s: String) => Paragraphs.extractCleanParagraph(s))
-
-    exploded.withColumn("parHash", xxhash64(cleanParUdf(exploded.col("text"))))
+    exploded
+  }
+
+  private def hashParagraphs(raw: DataFrame) = {
+    val exploded = splitTextToParagraphs(raw)
+    exploded.withColumn("parHash", xxhash64(cleanParagraphUdf(exploded.col("text"))))
  }

  def collectDocParts(
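Review note: the refactor shares a single split-and-clean path between dedup preparation and paragraph hashing. A toy run of that path (assumed data; the cleaning UDF is skipped and a local SparkSession is created purely for illustration):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{posexplode, split, xxhash64}

val spark = SparkSession.builder().master("local[1]").getOrCreate()
import spark.implicits._

// One row per document in, one row per paragraph out, keyed by position.
val docs = Seq("par one\n\npar two", "single paragraph").toDF("text")
val exploded = docs.select(posexplode(split($"text", "\n\n")).as(Seq("pos", "text")))

// parHash is a 64-bit hash of the paragraph text (the real code hashes the
// cleaned text via cleanParagraphUdf).
exploded.withColumn("parHash", xxhash64($"text")).show(truncate = false)
```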
@@ -1,8 +1,10 @@
package com.worksap.nlp.uzushio.lib.runners

+import com.worksap.nlp.uzushio.lib.cleaning.Paragraph
+import com.worksap.nlp.uzushio.lib.filters.KenLMEvaluator
import com.worksap.nlp.uzushio.lib.resources.{KenLM, Sudachi}
import com.worksap.nlp.uzushio.lib.utils.Paragraphs
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{SaveMode, SparkSession}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.{explode, udf}
import org.rogach.scallop.ScallopConf
@@ -21,27 +23,10 @@ object KenLMRunner {
class LMPerplexity(sudachi: String, kenlm: String) extends Serializable {

  @transient
-  private lazy val tokenizer = Sudachi.get(sudachi).create()
-
-  @transient
-  private lazy val evaluator = KenLM.get(kenlm).bufferEvaluator(64 * 1024, 1024)
+  private lazy val evaluator = KenLMEvaluator.make(sudachi, kenlm, 0.1f)

  def process(par: String): Double = {
-    val tokens = tokenizer.tokenize(par)
-    val proc = evaluator
-
-    proc.clear()
-
-    val iter = tokens.iterator()
-    var continue = true
-    while (iter.hasNext && continue) {
-      val token = iter.next()
-      if (token.normalizedForm() != " ") {
-        continue = proc.append(token.surface()) > 0
-      }
-    }
-
-    val prob = proc.evaluateNoOutliers(0.02f)
+    val prob = evaluator.scoreParagraph(Paragraph("body", par))
    Math.pow(10, -prob)
  }
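LMPerplexity now delegates to the shared evaluator instead of duplicating the token loop. Judging from KenLMEvaluator.make above, a ratio below 1e-3 yields the plain evaluator, so 0.1f opts into the outlier-trimming variant (its class name is truncated in this diff), replacing the old explicit evaluateNoOutliers(0.02f). A sketch with hypothetical resource paths:

```scala
// "sudachi.json" and "model.binary" are placeholder paths, not real resources.
val plain   = KenLMEvaluator.make("sudachi.json", "model.binary", 0f)   // ratio < 1e-3
val trimmed = KenLMEvaluator.make("sudachi.json", "model.binary", 0.1f) // as used above
```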

@@ -69,7 +54,7 @@ object KenLMRunner {
    val probs = pars.withColumn("perplexity", ppx.asUdf($"text"))
      .repartitionByRange(20, $"perplexity".desc).sortWithinPartitions($"perplexity".desc)

-    probs.write.json(opts.output())
+    probs.write.mode(SaveMode.Overwrite).json(opts.output())
  }

}
@@ -3,7 +3,7 @@ package com.worksap.nlp.uzushio.lib.runners
import com.worksap.nlp.uzushio.lib.utils.Resources.AutoClosableResource
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.LongType
-import org.apache.spark.sql.{SaveMode, SparkSession}
+import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
import org.rogach.scallop.ScallopConf

object MergeDedupStats {
@@ -13,6 +13,21 @@ object MergeDedupStats {

    val inputData = spark.read.parquet(arg.input(): _*)

+    val merged = mergeStatisticDatasets(spark, inputData)
+
+    val filtered =
+      if (arg.noOnes()) {
+        merged
+      } else merged.where($"nearFreq" > 1)
+
+    val partitioned = filtered.repartition(arg.partitions(), $"hash")
+      .sortWithinPartitions($"reprHash", $"hash")
+
+    partitioned.write.option("compression", "zstd").mode(SaveMode.Overwrite).parquet(arg.output())
+  }
+
+  def mergeStatisticDatasets(spark: SparkSession, inputData: DataFrame): DataFrame = {
+    import spark.implicits._
    val clampLongToInt = udf((x: Long) => math.min(x, Int.MaxValue).toInt).asNonNullable()

    val combined = inputData.groupBy("hash").agg(
@@ -26,7 +41,8 @@
    val remapReprHashes = notUnique.select("reprHashes").select(
      array_min($"reprHashes").as("newReprHash"),
      explode($"reprHashes").as("oldReprHash")
-    ).where($"newReprHash" =!= $"oldReprHash").distinct()
+    ).where($"newReprHash" =!= $"oldReprHash").groupBy($"oldReprHash")
+      .agg(min($"newReprHash").as("newReprHash"))
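Review note: this is the substantive fix in this file. With distinct(), one oldReprHash could still map to several newReprHash values (one per overlapping hash group), and the later join on it duplicated rows; grouping by oldReprHash and taking the minimum makes the remap single-valued. A toy reproduction (assumed data, local SparkSession for illustration):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{array_min, explode, min}

val spark = SparkSession.builder().master("local[1]").getOrCreate()
import spark.implicits._

val notUnique = Seq(Seq(3L, 7L), Seq(5L, 7L)).toDF("reprHashes")
val pairs = notUnique.select(
  array_min($"reprHashes").as("newReprHash"),
  explode($"reprHashes").as("oldReprHash")
).where($"newReprHash" =!= $"oldReprHash")

pairs.distinct().show() // two rows for oldReprHash = 7: (3, 7) and (5, 7)
pairs.groupBy($"oldReprHash")
  .agg(min($"newReprHash").as("newReprHash"))
  .show()               // one row: oldReprHash 7 -> newReprHash 3
```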

    val intermediate = combined.select(
      $"hash",
@@ -47,20 +63,14 @@
      max($"nearFreq").as("nearFreq")
    )

-    val merged = correctHashes.select("hash", "reprHash", "exactFreq").join(correctFreqs, "reprHash")
-
-    val filtered =
-      if (arg.noOnes()) {
-        merged
-      } else merged.where($"nearFreq" > 1)
-
-    filtered.write.option("compression", "zstd").mode(SaveMode.Overwrite).parquet(arg.output())
+    correctHashes.select("hash", "reprHash", "exactFreq").join(correctFreqs, "reprHash")
  }

  class Args(args: Seq[String]) extends ScallopConf(args) {
    val input = opt[List[String]]()
    val output = opt[String]()
    val master = opt[String]()
+    val partitions = opt[Int](default = Some(500))
    val noOnes = toggle(default = Some(false))
    verify()
  }
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@ object Paragraphs {
  final val FULLWIDTH_SPACE: Char = ' '

  def extractCleanParagraphs(text: String): Seq[String] = {
-    val paragraphs = StringUtils.split(text, PARAGRAPH_SEP)
+    val paragraphs = StringUtils.splitByWholeSeparator(text, PARAGRAPH_SEP)
    paragraphs.flatMap { x =>
      val par = extractCleanParagraph(x)
      if (hasContent(par)) Some(par) else None
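Review note on the one-line fix above: Commons Lang's StringUtils.split treats every character of the separator string as a split point, while splitByWholeSeparator splits only on the complete sequence. With a multi-character PARAGRAPH_SEP (its value is defined outside this diff; "\n\n" below is an assumed stand-in) the old call therefore also broke paragraphs on each lone character:

```scala
import org.apache.commons.lang3.StringUtils // assuming commons-lang3

val text = "line a\nline b\n\nnext paragraph"
StringUtils.split(text, "\n\n")
// Array("line a", "line b", "next paragraph") -- also split on single \n (the bug)
StringUtils.splitByWholeSeparator(text, "\n\n")
// Array("line a\nline b", "next paragraph")   -- split only on the full \n\n
```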
@@ -0,0 +1,57 @@
package com.worksap.nlp.uzushio.lib.utils

/** Iterates over sentence-like chunks of `input`: each chunk ends just after a
  * separator character, and chunks longer than `maxLength` are broken forcibly.
  */
class SentenceIterator(input: String, maxLength: Int) extends Iterator[String] {

  private var start = 0

  override def hasNext: Boolean = start < input.length

  override def next(): String = {
    val curStart = start
    var curEnd = SentenceIterator.indexOfSeparator(input, curStart, input.length) match {
      case -1 => input.length // no separator left: the rest is one chunk
      case x => x + 1 // include the separator itself in the chunk
    }

    val curLen = curEnd - curStart
    if (curLen > maxLength) {
      curEnd = curStart + maxLength
    }

    start = curEnd

    input.substring(curStart, curEnd)
  }
}

object SentenceIterator {
  private val SEPARATORS = "\n。、!?!?".toCharArray

  /** Returns the index of the first separator character in `input` within
    * `[start, end)`, or -1 if there is none. */
  def indexOfSeparator(input: CharSequence, start: Int, end: Int): Int = {
    val seps = SEPARATORS
    val nseps = seps.length

    if (start < 0 || start > input.length()) {
      throw new IndexOutOfBoundsException()
    }

    if (end < 0 || end > input.length()) {
      throw new IndexOutOfBoundsException()
    }

    var i = start
    while (i < end) {
      val ch = input.charAt(i)
      var j = 0
      while (j < nseps) {
        val ch0 = seps(j)
        if (ch == ch0) {
          return i
        }
        j += 1
      }
      i += 1
    }
    -1
  }
}
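A short usage sketch of the iterator above (assumed input strings): separators are the newline plus Japanese sentence and clause punctuation, each chunk keeps its terminating separator, and overlong runs are broken at maxLength:

```scala
val it = new SentenceIterator("短い文。二つ目の文です!改行も\n区切りです", 1024)
it.foreach(println)
// 短い文。
// 二つ目の文です!
// 改行も\n      <- the newline stays at the end of its chunk
// 区切りです

// With a small maxLength, a separator-free run is split mid-text:
new SentenceIterator("あいうえおかきくけこ", 5).toList // List(あいうえお, かきくけこ)
```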