diff --git a/Main.cs b/Main.cs index 663c729..cde3e3b 100644 --- a/Main.cs +++ b/Main.cs @@ -120,7 +120,7 @@ public void Append(IEnumerable values) } } -class CommandBuilder +static class CommandBuilder { public static string Raw(string command) { @@ -251,9 +251,12 @@ public static IEnumerable WriteToFile(string path, IEnumerable l return lines.SelectMany(line => WriteToFile(path, line)); } - public static IEnumerable SafeWriteToFile(string path, string content) + public static IEnumerable SafeWriteToFile(string path, string content, Encoding encoding) { - byte[] bytes = Encoding.UTF8.GetBytes(content); + byte[] bytes = Enumerable.Concat( + encoding.GetPreamble(), + encoding.GetBytes(content) + ).ToArray(); int chunkSize = 256 - 70; foreach (string base64 in Convert.ToBase64String(bytes).Chunk(chunkSize).Select(chars => new string(chars))) diff --git a/modifier/Script.cs b/modifier/Script.cs index ee77eb7..62c07fe 100644 --- a/modifier/Script.cs +++ b/modifier/Script.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Text; namespace Schneegans.Unattend; @@ -24,6 +25,25 @@ public static IEnumerable GetAllowedTypes(this ScriptPhase phase) _ => Enum.GetValues(), }; } + + public static string FileExtension(this ScriptType type) + { + return '.' + type.ToString().ToLowerInvariant(); + } + + public static Encoding PreferredEncoding(this ScriptType type) + { + UnicodeEncoding utf16WithBom = new(bigEndian: false, byteOrderMark: true); + return type switch + { + ScriptType.Ps1 => Encoding.UTF8, + ScriptType.Cmd => Encoding.Latin1, + ScriptType.Reg => utf16WithBom, + ScriptType.Vbs => utf16WithBom, + ScriptType.Js => utf16WithBom, + _ => throw new NotImplementedException(), + }; + } } public record class ScriptSettings( @@ -39,6 +59,15 @@ public Script(string content, ScriptPhase phase, ScriptType type) throw new ConfigurationException($"Scripts in phase '{phase}' must not have type '{type}'."); } + if (phase == ScriptPhase.DefaultUser && type == ScriptType.Reg && !string.IsNullOrWhiteSpace(content)) + { + string prefix = @"[HKEY_USERS\DefaultUser\"; + if (!content.Contains(prefix, StringComparison.OrdinalIgnoreCase)) + { + throw new ConfigurationException($"{type.FileExtension()} script '{content}' does not contain required key prefix '{prefix}'."); + } + } + Content = content; Phase = phase; Type = type; @@ -111,7 +140,7 @@ static string Clean(Script script) var appender = new CommandAppender(Document, NamespaceManager, CommandConfig.Specialize); appender.Append( - CommandBuilder.SafeWriteToFile(scriptId.FullName, Clean(script)) + CommandBuilder.SafeWriteToFile(scriptId.FullName, Clean(script), script.Type.PreferredEncoding()) ); }