| /* |
| * Cryptographic API. |
| * |
| * Copyright (c) 2017-present, Facebook, Inc. |
| * |
| * This program is free software; you can redistribute it and/or modify it |
| * under the terms of the GNU General Public License version 2 as published by |
| * the Free Software Foundation. |
| * |
| * This program is distributed in the hope that it will be useful, but WITHOUT |
| * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or |
| * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for |
| * more details. |
| */ |
| #include <linux/crypto.h> |
| #include <linux/init.h> |
| #include <linux/interrupt.h> |
| #include <linux/mm.h> |
| #include <linux/module.h> |
| #include <linux/net.h> |
| #include <linux/vmalloc.h> |
| #include <linux/zstd.h> |
| #include <crypto/internal/scompress.h> |
| |
| |
| #define ZSTD_DEF_LEVEL 1 |
| |
| struct zstd_ctx { |
| ZSTD_CCtx *cctx; |
| ZSTD_DCtx *dctx; |
| void *cwksp; |
| void *dwksp; |
| }; |
| |
| static ZSTD_parameters zstd_params(void) |
| { |
| return ZSTD_getParams(ZSTD_DEF_LEVEL, PAGE_SIZE, 0); |
| } |
| |
| static int zstd_comp_init(struct zstd_ctx *ctx) |
| { |
| int ret = 0; |
| const ZSTD_parameters params = zstd_params(); |
| const size_t wksp_size = ZSTD_CCtxWorkspaceBound(params.cParams); |
| |
| ctx->cwksp = vzalloc(wksp_size); |
| if (!ctx->cwksp) { |
| ret = -ENOMEM; |
| goto out; |
| } |
| |
| ctx->cctx = ZSTD_initCCtx(ctx->cwksp, wksp_size); |
| if (!ctx->cctx) { |
| ret = -EINVAL; |
| goto out_free; |
| } |
| out: |
| return ret; |
| out_free: |
| vfree(ctx->cwksp); |
| goto out; |
| } |
| |
| static int zstd_decomp_init(struct zstd_ctx *ctx) |
| { |
| int ret = 0; |
| const size_t wksp_size = ZSTD_DCtxWorkspaceBound(); |
| |
| ctx->dwksp = vzalloc(wksp_size); |
| if (!ctx->dwksp) { |
| ret = -ENOMEM; |
| goto out; |
| } |
| |
| ctx->dctx = ZSTD_initDCtx(ctx->dwksp, wksp_size); |
| if (!ctx->dctx) { |
| ret = -EINVAL; |
| goto out_free; |
| } |
| out: |
| return ret; |
| out_free: |
| vfree(ctx->dwksp); |
| goto out; |
| } |
| |
| static void zstd_comp_exit(struct zstd_ctx *ctx) |
| { |
| vfree(ctx->cwksp); |
| ctx->cwksp = NULL; |
| ctx->cctx = NULL; |
| } |
| |
| static void zstd_decomp_exit(struct zstd_ctx *ctx) |
| { |
| vfree(ctx->dwksp); |
| ctx->dwksp = NULL; |
| ctx->dctx = NULL; |
| } |
| |
| static int __zstd_init(void *ctx) |
| { |
| int ret; |
| |
| ret = zstd_comp_init(ctx); |
| if (ret) |
| return ret; |
| ret = zstd_decomp_init(ctx); |
| if (ret) |
| zstd_comp_exit(ctx); |
| return ret; |
| } |
| |
| static void *zstd_alloc_ctx(struct crypto_scomp *tfm) |
| { |
| int ret; |
| struct zstd_ctx *ctx; |
| |
| ctx = kzalloc(sizeof(*ctx), GFP_KERNEL); |
| if (!ctx) |
| return ERR_PTR(-ENOMEM); |
| |
| ret = __zstd_init(ctx); |
| if (ret) { |
| kfree(ctx); |
| return ERR_PTR(ret); |
| } |
| |
| return ctx; |
| } |
| |
| static int zstd_init(struct crypto_tfm *tfm) |
| { |
| struct zstd_ctx *ctx = crypto_tfm_ctx(tfm); |
| |
| return __zstd_init(ctx); |
| } |
| |
| static void __zstd_exit(void *ctx) |
| { |
| zstd_comp_exit(ctx); |
| zstd_decomp_exit(ctx); |
| } |
| |
| static void zstd_free_ctx(struct crypto_scomp *tfm, void *ctx) |
| { |
| __zstd_exit(ctx); |
| kzfree(ctx); |
| } |
| |
| static void zstd_exit(struct crypto_tfm *tfm) |
| { |
| struct zstd_ctx *ctx = crypto_tfm_ctx(tfm); |
| |
| __zstd_exit(ctx); |
| } |
| |
| static int __zstd_compress(const u8 *src, unsigned int slen, |
| u8 *dst, unsigned int *dlen, void *ctx) |
| { |
| size_t out_len; |
| struct zstd_ctx *zctx = ctx; |
| const ZSTD_parameters params = zstd_params(); |
| |
| out_len = ZSTD_compressCCtx(zctx->cctx, dst, *dlen, src, slen, params); |
| if (ZSTD_isError(out_len)) |
| return -EINVAL; |
| *dlen = out_len; |
| return 0; |
| } |
| |
| static int zstd_compress(struct crypto_tfm *tfm, const u8 *src, |
| unsigned int slen, u8 *dst, unsigned int *dlen) |
| { |
| struct zstd_ctx *ctx = crypto_tfm_ctx(tfm); |
| |
| return __zstd_compress(src, slen, dst, dlen, ctx); |
| } |
| |
| static int zstd_scompress(struct crypto_scomp *tfm, const u8 *src, |
| unsigned int slen, u8 *dst, unsigned int *dlen, |
| void *ctx) |
| { |
| return __zstd_compress(src, slen, dst, dlen, ctx); |
| } |
| |
| static int __zstd_decompress(const u8 *src, unsigned int slen, |
| u8 *dst, unsigned int *dlen, void *ctx) |
| { |
| size_t out_len; |
| struct zstd_ctx *zctx = ctx; |
| |
| out_len = ZSTD_decompressDCtx(zctx->dctx, dst, *dlen, src, slen); |
| if (ZSTD_isError(out_len)) |
| return -EINVAL; |
| *dlen = out_len; |
| return 0; |
| } |
| |
| static int zstd_decompress(struct crypto_tfm *tfm, const u8 *src, |
| unsigned int slen, u8 *dst, unsigned int *dlen) |
| { |
| struct zstd_ctx *ctx = crypto_tfm_ctx(tfm); |
| |
| return __zstd_decompress(src, slen, dst, dlen, ctx); |
| } |
| |
| static int zstd_sdecompress(struct crypto_scomp *tfm, const u8 *src, |
| unsigned int slen, u8 *dst, unsigned int *dlen, |
| void *ctx) |
| { |
| return __zstd_decompress(src, slen, dst, dlen, ctx); |
| } |
| |
| static struct crypto_alg alg = { |
| .cra_name = "zstd", |
| .cra_flags = CRYPTO_ALG_TYPE_COMPRESS, |
| .cra_ctxsize = sizeof(struct zstd_ctx), |
| .cra_module = THIS_MODULE, |
| .cra_init = zstd_init, |
| .cra_exit = zstd_exit, |
| .cra_u = { .compress = { |
| .coa_compress = zstd_compress, |
| .coa_decompress = zstd_decompress } } |
| }; |
| |
| static struct scomp_alg scomp = { |
| .alloc_ctx = zstd_alloc_ctx, |
| .free_ctx = zstd_free_ctx, |
| .compress = zstd_scompress, |
| .decompress = zstd_sdecompress, |
| .base = { |
| .cra_name = "zstd", |
| .cra_driver_name = "zstd-scomp", |
| .cra_module = THIS_MODULE, |
| } |
| }; |
| |
| static int __init zstd_mod_init(void) |
| { |
| int ret; |
| |
| ret = crypto_register_alg(&alg); |
| if (ret) |
| return ret; |
| |
| ret = crypto_register_scomp(&scomp); |
| if (ret) |
| crypto_unregister_alg(&alg); |
| |
| return ret; |
| } |
| |
| static void __exit zstd_mod_fini(void) |
| { |
| crypto_unregister_alg(&alg); |
| crypto_unregister_scomp(&scomp); |
| } |
| |
| module_init(zstd_mod_init); |
| module_exit(zstd_mod_fini); |
| |
| MODULE_LICENSE("GPL"); |
| MODULE_DESCRIPTION("Zstd Compression Algorithm"); |
| MODULE_ALIAS_CRYPTO("zstd"); |